[fastllm]多线程下动态组batch实现解析
[fastllm]多线程下动态组batch实现解析
需求分析
新版本的fastllm中添加了ForwardBatch的功能,用于处理批量推理请求,单次推理请求会被视为batch为1的批量请求,这样做似乎没什么问题。
然而在具体实践中,用户的请求往往是一个一个来的,每来一个请求都要等上一个请求完成之后才能开始推理下一个请求,一旦并发数起来,体验将及其糟糕。
所幸,stream流式输出能够在一定程度上缓解这个问题,web前端调用是一个异步线程队列,那么多个用户间的web 前端IO的时间差恰好可以给推理留出一定时间,虽然后台依然是一个token一个token地进行推理,但前端却看起来能够多用户并发使用。
但不幸的是,这种方法只在一定程度上解决问题,当用户量变多时,后台由于大量堆积的待推理的tokens,用户的体验又将变得十分糟糕。
在fastllm中,使用的是动态组batch的方案,即当A请求正在运行的token和B请求正在运行的token进行组合,合并为一个batch,在模型中并行推理,以提高模型实际运行时的吞吐。
具体实现
主要实现为两个函数,LaunchResponseTokens和FetchResponseTokens,其中LaunchResponseTokens主要根据当前输入增加一个监听token推理结果的线程,并返回context的handle;而FetchResponseTokens则根据给定的handle去队列中fetch对应的token结果。
LaunchResponseTokens的实现
LaunchResponseTokens函数可以拆成两部分去看,第一部分是mainLoopLocker.lock();到mainLoopLocker.unlock();这部分主要是创建并维护主线程。第二部分则是从dictLocker.lock();到dictLocker.unlock();这部分则是创建handle并向responseContextDict中添加初始化参数。
第二部分比较简单,可以忽略,要重点看看第一部分。第一部分是一个循环,可细分为前处理期、运行期和后处理期3部分,以model->dictLocker.lock();进行区分,是因为修改的是全局的dict。其中前处理期主要根据model config得到attentionMasks和positionIds,这两者都是以 std::vector <Data*> positionIds的类型存储的,不同的handle存储不同的参数即可,比较有意思的是inputIds,它是一个[1, all inputs len]的向量,所以需要一个seqLens来记录每个线程对应的inputs的长度。在运行期则将合成的batch输出送入到模型当中去,以并行的方式运行,不过在笔者的这版源码中,除了涉及inputs_ids部分是并行处理外,其他都是将batch进行展开计算的,也即在Attention之前的layerNorm、QKV Linear以及Attention后的FFN是多batch并行计算的,关键的Attention部分由于涉及到attentionMasks和positionIds还是需要拆batch来进行计算。不过在最新的代码中,作者已经将所有线程都处在token len为1时的状况进行了优化,这在多个长文本回复将有比较明显的加速。最后是后处理部分,这部分将各个线程对应的token取出,放入到消费者队列中,等待FetchResponseTokens的fetch。
PS: 这里mainLoop线程在启动前加了双重判断,理应来说mainLoopLocker.lock();应该是放在第一重判断和第二重判断之间的,如果在外面加lock的话,一重判断就应该是ok的。不知道是我理解错了,还是作者手滑,有待验证。
int ChatGLMModel::LaunchResponseTokens(const std::vector<int> &inputTokens, const GenerationConfig &generationConfig) { mainLoopLocker.lock(); if (mainLoop == nullptr) { if (mainLoop == nullptr) { mainLoop = new std::thread([](ChatGLMModel *model) { while (true) { std::vector <Data*> attentionMasks; std::vector <Data*> positionIds; std::vector <std::pair <Data*, Data*> > pastKeyValues; std::vector <float> ids; std::vector <int> seqLens; std::vector <int> handles; std::vector <GenerationConfig> generationConfigs; LastTokensManager tokensManager; model->dictLocker.lock(); for (auto &it: model->responseContextDict.dicts) { if (it.second->isEnding) { continue; } generationConfigs.push_back(it.second->generationConfig); tokensManager.units.push_back(it.second->tokens); handles.push_back(it.first); for (int i = 0; i < it.second->currentTokens.size(); i++) { ids.push_back(it.second->currentTokens[i]); } if (it.second->preTokens == 0) { int seqLen = it.second->currentTokens.size(); if (model->GetVersion() == 1) { int gmask_token_id = model->weight.dicts.find("gmask_token_id") != model->weight.dicts.end() ? atoi(model->weight.dicts["gmask_token_id"].c_str()) : 130001; if (it.second->currentTokens.size() < 2 || it.second->currentTokens.back() != model->bos_token_id) { ids.push_back(gmask_token_id); ids.push_back(model->bos_token_id); seqLen += 2; } } else { if (it.second->currentTokens.size() < 2 || it.second->currentTokens[0] != 64790) { ids.insert(ids.begin() + (ids.size() - it.second->currentTokens.size()), 64790); ids.insert(ids.begin() + (ids.size() - it.second->currentTokens.size()), 64792); seqLen += 2; } } seqLens.push_back(seqLen); std::vector<float> vmask = std::vector<float>(seqLen * seqLen, 0); std::vector<float> vpids = std::vector<float>(seqLen * 2, 0); for (int i = 0; i < seqLen - 1; i++) { vmask[i * seqLen + seqLen - 1] = 1; vpids[i] = i; } vpids[seqLen - 1] = seqLen - 2; vpids[seqLen * 2 - 1] = 1; if (model->GetVersion() == 2) { for (int i = 0; i < seqLen; i++) { vpids[i] = i; for (int j = i + 1; j < seqLen; j++) { vmask[i * seqLen + j] = 1; } } } it.second->intParams["maskIds"] = seqLen - (model->GetVersion() == 1 ? 2 : 0); it.second->intParams["len"] = 1; attentionMasks.push_back(new Data(DataType::FLOAT32, {seqLen, seqLen}, vmask)); positionIds.push_back(new Data(DataType::FLOAT32, {2, seqLen}, vpids)); } else { seqLens.push_back(1); it.second->intParams["len"]++; attentionMasks.push_back(nullptr); positionIds.push_back(new Data(DataType::FLOAT32, {2, 1}, {(float)it.second->intParams["maskIds"], (float)(it.second->intParams["len"])})); if (model->GetVersion() == 2) { it.second->intParams["maskIds"]++; } } it.second->preTokens += seqLens.back(); for (int i = 0; i < model->block_cnt; i++) { pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first, &it.second->pastKeyValues[i].second)); } } if (seqLens.size() > 0) { model->dictLocker.unlock(); #ifdef USE_CUDA FastllmCudaClearBigBuffer(); #endif Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids); std::vector<int> ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks, positionIds, seqLens, pastKeyValues, generationConfigs, tokensManager); model->dictLocker.lock(); for (int i = 0; i < handles.size(); i++) { auto &it = *model->responseContextDict.dicts.find(handles[i]); int curRet = ret[i]; if (curRet == model->eos_token_id) { it.second->isEnding = true; } else { it.second->currentTokens = std::vector<int>{curRet}; it.second->resultTokenQueue.push(curRet); it.second->tokens.Push(curRet); it.second->curTokens++; if (it.second->curTokens == it.second->generationConfig.output_token_limit) { it.second->isEnding = true; } } } } for (int i = 0; i < attentionMasks.size(); i++) { delete attentionMasks[i]; } for (int i = 0; i < positionIds.size(); i++) { delete positionIds[i]; } model->dictLocker.unlock(); MySleep(0); } }, this); } } mainLoopLocker.unlock(); dictLocker.lock(); int handleId = responseContextDict.CreateHandle(); ResponseContext *context = responseContextDict.GetHandle(handleId); context->Init(this->block_cnt); context->currentTokens = inputTokens; context->generationConfig = generationConfig; context->tokens = LastTokensUnit(generationConfig.last_n); dictLocker.unlock(); return handleId; }
FetchResponseTokens函数的实现
这部分功能就是消费者,从消费者队列中取之前生成的token即可。
实现逻辑上比较简单,从responseContextDict根据handle,找到对应的context,然后循环不断地fetch他token直到ending即可。这里有个有意思的问题,while (true)是在不断轮询队列中的token,实际上是一种简单但不太高效的写法,生产者消费者问题在系统中是一个很经典的问题。
int ChatGLMModel::FetchResponseTokens(int handleId) { dictLocker.lock(); ResponseContext *context = responseContextDict.GetHandle(handleId); if (context == nullptr) { dictLocker.unlock(); return -1; } else { while (true) { if (context->resultTokenQueue.size() > 0) { int ret = context->resultTokenQueue.front(); context->resultTokenQueue.pop(); dictLocker.unlock(); return ret; } else { if (context->isEnding) { responseContextDict.RemoveHandle(handleId); dictLocker.unlock(); return -1; } } dictLocker.unlock(); MySleep(0); dictLocker.lock(); } } }
总结与讨论
通过构建context封装的的方式来对token进行管理,通过context字典来记录不同线程的的tokens,主线程中则对多个线程下的token和输入配置进行拼接,batch并行推理后并将结果写入到各个context中,前台则通过不同handle取对应的token,这种设计可以极大提高系统的吞吐,增强用户体验。
不过仍然有一些可讨论的点,比如forwardbatch中参数可以改为纯Data*类型的数据,不过这样的话就需要1. 区分第一次batch和后续batch,在实现上第一次运行不组batch 或者2. 进行padding,但这不是一个太好的思路。另外就是自定义手写的函数可以这么玩,但如果是onnx、Trt类似的静态图,做这样的实现可能会有一些困扰。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 三行代码完成国际化适配,妙~啊~
· .NET Core 中如何实现缓存的预热?