[fastllm]多线程下动态组batch实现解析
[fastllm]多线程下动态组batch实现解析
需求分析
新版本的fastllm中添加了ForwardBatch的功能,用于处理批量推理请求,单次推理请求会被视为batch为1的批量请求,这样做似乎没什么问题。
然而在具体实践中,用户的请求往往是一个一个来的,每来一个请求都要等上一个请求完成之后才能开始推理下一个请求,一旦并发数起来,体验将及其糟糕。
所幸,stream流式输出能够在一定程度上缓解这个问题,web前端调用是一个异步线程队列,那么多个用户间的web 前端IO的时间差恰好可以给推理留出一定时间,虽然后台依然是一个token一个token地进行推理,但前端却看起来能够多用户并发使用。
但不幸的是,这种方法只在一定程度上解决问题,当用户量变多时,后台由于大量堆积的待推理的tokens,用户的体验又将变得十分糟糕。
在fastllm中,使用的是动态组batch的方案,即当A请求正在运行的token和B请求正在运行的token进行组合,合并为一个batch,在模型中并行推理,以提高模型实际运行时的吞吐。
具体实现
主要实现为两个函数,LaunchResponseTokens和FetchResponseTokens,其中LaunchResponseTokens主要根据当前输入增加一个监听token推理结果的线程,并返回context的handle;而FetchResponseTokens则根据给定的handle去队列中fetch对应的token结果。
LaunchResponseTokens的实现
LaunchResponseTokens函数可以拆成两部分去看,第一部分是mainLoopLocker.lock();到mainLoopLocker.unlock();这部分主要是创建并维护主线程。第二部分则是从dictLocker.lock();到dictLocker.unlock();这部分则是创建handle并向responseContextDict中添加初始化参数。
第二部分比较简单,可以忽略,要重点看看第一部分。第一部分是一个循环,可细分为前处理期、运行期和后处理期3部分,以model->dictLocker.lock();进行区分,是因为修改的是全局的dict。其中前处理期主要根据model config得到attentionMasks和positionIds,这两者都是以 std::vector <Data*> positionIds的类型存储的,不同的handle存储不同的参数即可,比较有意思的是inputIds,它是一个[1, all inputs len]的向量,所以需要一个seqLens来记录每个线程对应的inputs的长度。在运行期则将合成的batch输出送入到模型当中去,以并行的方式运行,不过在笔者的这版源码中,除了涉及inputs_ids部分是并行处理外,其他都是将batch进行展开计算的,也即在Attention之前的layerNorm、QKV Linear以及Attention后的FFN是多batch并行计算的,关键的Attention部分由于涉及到attentionMasks和positionIds还是需要拆batch来进行计算。不过在最新的代码中,作者已经将所有线程都处在token len为1时的状况进行了优化,这在多个长文本回复将有比较明显的加速。最后是后处理部分,这部分将各个线程对应的token取出,放入到消费者队列中,等待FetchResponseTokens的fetch。
PS: 这里mainLoop线程在启动前加了双重判断,理应来说mainLoopLocker.lock();应该是放在第一重判断和第二重判断之间的,如果在外面加lock的话,一重判断就应该是ok的。不知道是我理解错了,还是作者手滑,有待验证。
int ChatGLMModel::LaunchResponseTokens(const std::vector<int> &inputTokens,
const GenerationConfig &generationConfig) {
mainLoopLocker.lock();
if (mainLoop == nullptr) {
if (mainLoop == nullptr) {
mainLoop = new std::thread([](ChatGLMModel *model) {
while (true) {
std::vector <Data*> attentionMasks;
std::vector <Data*> positionIds;
std::vector <std::pair <Data*, Data*> > pastKeyValues;
std::vector <float> ids;
std::vector <int> seqLens;
std::vector <int> handles;
std::vector <GenerationConfig> generationConfigs;
LastTokensManager tokensManager;
model->dictLocker.lock();
for (auto &it: model->responseContextDict.dicts) {
if (it.second->isEnding) {
continue;
}
generationConfigs.push_back(it.second->generationConfig);
tokensManager.units.push_back(it.second->tokens);
handles.push_back(it.first);
for (int i = 0; i < it.second->currentTokens.size(); i++) {
ids.push_back(it.second->currentTokens[i]);
}
if (it.second->preTokens == 0) {
int seqLen = it.second->currentTokens.size();
if (model->GetVersion() == 1) {
int gmask_token_id =
model->weight.dicts.find("gmask_token_id") != model->weight.dicts.end() ?
atoi(model->weight.dicts["gmask_token_id"].c_str()) : 130001;
if (it.second->currentTokens.size() < 2 ||
it.second->currentTokens.back() != model->bos_token_id) {
ids.push_back(gmask_token_id);
ids.push_back(model->bos_token_id);
seqLen += 2;
}
} else {
if (it.second->currentTokens.size() < 2 ||
it.second->currentTokens[0] != 64790) {
ids.insert(ids.begin() + (ids.size() - it.second->currentTokens.size()), 64790);
ids.insert(ids.begin() + (ids.size() - it.second->currentTokens.size()), 64792);
seqLen += 2;
}
}
seqLens.push_back(seqLen);
std::vector<float> vmask = std::vector<float>(seqLen * seqLen, 0);
std::vector<float> vpids = std::vector<float>(seqLen * 2, 0);
for (int i = 0; i < seqLen - 1; i++) {
vmask[i * seqLen + seqLen - 1] = 1;
vpids[i] = i;
}
vpids[seqLen - 1] = seqLen - 2;
vpids[seqLen * 2 - 1] = 1;
if (model->GetVersion() == 2) {
for (int i = 0; i < seqLen; i++) {
vpids[i] = i;
for (int j = i + 1; j < seqLen; j++) {
vmask[i * seqLen + j] = 1;
}
}
}
it.second->intParams["maskIds"] = seqLen - (model->GetVersion() == 1 ? 2 : 0);
it.second->intParams["len"] = 1;
attentionMasks.push_back(new Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
positionIds.push_back(new Data(DataType::FLOAT32, {2, seqLen}, vpids));
} else {
seqLens.push_back(1);
it.second->intParams["len"]++;
attentionMasks.push_back(nullptr);
positionIds.push_back(new Data(DataType::FLOAT32, {2, 1}, {(float)it.second->intParams["maskIds"], (float)(it.second->intParams["len"])}));
if (model->GetVersion() == 2) {
it.second->intParams["maskIds"]++;
}
}
it.second->preTokens += seqLens.back();
for (int i = 0; i < model->block_cnt; i++) {
pastKeyValues.push_back(std::make_pair(&it.second->pastKeyValues[i].first,
&it.second->pastKeyValues[i].second));
}
}
if (seqLens.size() > 0) {
model->dictLocker.unlock();
#ifdef USE_CUDA
FastllmCudaClearBigBuffer();
#endif
Data inputIds = Data(DataType::FLOAT32, {1, (int) ids.size()}, ids);
std::vector<int> ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks,
positionIds, seqLens, pastKeyValues, generationConfigs, tokensManager);
model->dictLocker.lock();
for (int i = 0; i < handles.size(); i++) {
auto &it = *model->responseContextDict.dicts.find(handles[i]);
int curRet = ret[i];
if (curRet == model->eos_token_id) {
it.second->isEnding = true;
} else {
it.second->currentTokens = std::vector<int>{curRet};
it.second->resultTokenQueue.push(curRet);
it.second->tokens.Push(curRet);
it.second->curTokens++;
if (it.second->curTokens == it.second->generationConfig.output_token_limit) {
it.second->isEnding = true;
}
}
}
}
for (int i = 0; i < attentionMasks.size(); i++) {
delete attentionMasks[i];
}
for (int i = 0; i < positionIds.size(); i++) {
delete positionIds[i];
}
model->dictLocker.unlock();
MySleep(0);
}
}, this);
}
}
mainLoopLocker.unlock();
dictLocker.lock();
int handleId = responseContextDict.CreateHandle();
ResponseContext *context = responseContextDict.GetHandle(handleId);
context->Init(this->block_cnt);
context->currentTokens = inputTokens;
context->generationConfig = generationConfig;
context->tokens = LastTokensUnit(generationConfig.last_n);
dictLocker.unlock();
return handleId;
}
FetchResponseTokens函数的实现
这部分功能就是消费者,从消费者队列中取之前生成的token即可。
实现逻辑上比较简单,从responseContextDict根据handle,找到对应的context,然后循环不断地fetch他token直到ending即可。这里有个有意思的问题,while (true)是在不断轮询队列中的token,实际上是一种简单但不太高效的写法,生产者消费者问题在系统中是一个很经典的问题。
int ChatGLMModel::FetchResponseTokens(int handleId) {
dictLocker.lock();
ResponseContext *context = responseContextDict.GetHandle(handleId);
if (context == nullptr) {
dictLocker.unlock();
return -1;
} else {
while (true) {
if (context->resultTokenQueue.size() > 0) {
int ret = context->resultTokenQueue.front();
context->resultTokenQueue.pop();
dictLocker.unlock();
return ret;
} else {
if (context->isEnding) {
responseContextDict.RemoveHandle(handleId);
dictLocker.unlock();
return -1;
}
}
dictLocker.unlock();
MySleep(0);
dictLocker.lock();
}
}
}
总结与讨论
通过构建context封装的的方式来对token进行管理,通过context字典来记录不同线程的的tokens,主线程中则对多个线程下的token和输入配置进行拼接,batch并行推理后并将结果写入到各个context中,前台则通过不同handle取对应的token,这种设计可以极大提高系统的吞吐,增强用户体验。
不过仍然有一些可讨论的点,比如forwardbatch中参数可以改为纯Data*类型的数据,不过这样的话就需要1. 区分第一次batch和后续batch,在实现上第一次运行不组batch 或者2. 进行padding,但这不是一个太好的思路。另外就是自定义手写的函数可以这么玩,但如果是onnx、Trt类似的静态图,做这样的实现可能会有一些困扰。