Inferllm源码解析

Inferllm源码解析

文件结构

  1. application: 放置几个不同模型的参数配置和后处理
  2. include: 包含抽象model类的头文件
  3. src/core: 核心组件,包括tensor、算法等基础算子的抽象和KV文件系统的实现
  4. src/graph: 包含了几种LLM模型的具体实现
  5. src/kern: 包含了不同硬件下的算子实现
  6. src: 剩余一些其他公共函数实现

继承与组合关系

chat/alpaca/chatglm是外置的application程序,通过该入口设置模型随机种子、线程数、token等inferllm::ModelConfig模型参数
利用这些模型参数构建inferllm::Model的shared_ptr,经过load、init、decode_iter等操作进行编解码,设置fix_word函数对结果后处理

关键inferllm::Model类中包含了实际执行类inferllm::ModelImp,而inferllm::ModelImp则包含了inferllms::device的unique_ptr和inferllms::Graph的shared_ptr以及Vocab的shared_ptr,在实际计算时,则通过inferllm::ModelImp中的graph进行execute,得到token,将token进行解码并返回。

alpaca
chat/baichuan/llama
chatglm
-- inferllm::mdoel
-- inferllm::ModelImp
-- inferllms::Tensor
-- inferllms::device(CPUDevice, GPUdevice)
-- inferllm::ThreadPool
-- inferllm::Kernel
-- inferllm::KernelID
-- inferllm::LlmParams
-- inferllm::Graph(llamaGraph, chatglmGraph, baichuanGraph)
-- inferllm::OprModuleBase()
-- inferllm::OpBase(LayerNorm, Embedding, SoftMax, Elemwise, MatMul, MatMulLast, LlamaAttention, GlmAttention, DiagMask)
-- inferllm::Kernel
-- inferllm::UserConfig
-- inferllm::WorkSpace
-- inferllm::Tensor
-- inferllm:: Vocab
-- inferllm::ModelConfig
-- compt_type
-- nr_thread
-- nr_ctx
-- device_id

核心类属性分析

model_imp类

所有model的抽象接口,也是所有model的基类,当其他模型运行时,使用该基类的shared_ptr对象,借助多态的方法实现模型的参数加载load函数,模型填充prefill函数以及token的encode和decode操作,

load: 加载模型文件,将权重加载到graph中,并将最终输出logist重置为对应的vocab_len
prefill: 类似于warmup函数,将token填入到模型中
decode: 第一次运行的token,放于网络中运行并解码
decode_iter:非第一次的模型运行tokens,取top_k个返回
sample_and_update:加入惩罚因子对输出进行惩罚,并选出top_p个作为下一个token
tokenize: 将文本编码为对应的input_ids
decode_summary: 将编解码效率和速度总结输出

OprModuleBase类

是所有opr操作的基础类,有添加opr函数,get_all_weights函数,获取inputs函数,获取output函数,name名字,device设备等函数。

execute:有pre_execute、execute、end_execute对op的输入输出进行前处理和后处理
get_workspace_in_byte:得到当前Module中的所有op中最大占用空间

LlamaFFNModule类

包含了2次matmul乘积silu激活函数并进行残差操作,然后与w3进行matmul得到最终输出结果。

GlmFFNModule类

包含了一次matmul乘积gelu激活函数,然后与w2相乘得到最后输出。

HeadModule类

包含了一次layer_norm,对输入进行norm和matmul得到最终输出。

EmbdModule

Embedding模块,通过input_ids找到对应的embedding编码。

OpBase类

所有Op基础类,包含了执行预处理操作,执行操作,后执行操作、set_name、set_outputs、set_weights等操作。

LayerNorm类、Embedding类、SoftMax类、Elemwise类、MatMul类、MatMulLast类、LlamaAttention类、GlmAttention类、DiagMask类

调用两种不同的kernel,RmsNormFloat和NormFloat,对输入数据进行归一化处理。
调用EmbeddingGetFloatFloat kernel进行使用不同的device计算。
调用softmax类在不同的device上进行操作

Kernel实现

llm_elemwise_broadcast_dim0_src1_compute_float_add_gpu、llm_elemwise_compute_float_scale_gpu

实现elemwise的乘法和加法,每个线程计算一次乘法或者加法,并且当第二个数的维度小于第一个的维度时,采用dim0上广播expand。
例如:
[[1, 2, 3,], [4, 5, 6]] * [[7, 8, 9]]
这时,如果乘完计算123之后,会循环计算,第二个矩阵就在dim0上进行广播扩充。

llm_elemwise_broadcast_dim0_src1_compute_float

计算当前机器的blocks,设置不同的blocks,找到对应不同的乘加算法。

ApplyFunction

函数模板,根据不同的函数function去实现不同的计算逻辑

LaunchKernel

带安全检查和block参数配置的应用函数

llm_softmax_compute_float_gpu

先找到最大值,val = exp(v-max_v), sum += val_all, softmax(val_i / sum)
并行策略:每次计算一行上的softmax,多线程并行计算多行,线程数为行数

llm_norm_compute_float_gpu

得到每一行中的方差scale,v[i] *= scale
并行策略:每次计算一个seq_len上的norm,并行线程数为seq_len

llm_embedding_get_float_float

将embedding的头指针拷贝到cuda中,并将每一行对应的数据拷贝到cuda中

dequantize_row_q4_0_reference_gpu

__restrict用法
反量化操作,加了unroll操作,每32个数计算一次,其中反量化算子下高4位为第一个数,低4位为第二个数,并在计算时加入了scale因子。
并行策略:在embedding量化中并行线程数为seq_len

SiluFunctor、GeluFunctor、AddFunctor、MulFunctor

cpu计算函数,一些常见的函数

llm_rms_norm_compute_float_gpu

得到每一行中的方差scale,v[i] *= scale
并行策略:每次计算一个seq_len上的norm,并行线程数为seq_len

llm_rope_compute_float_gpu

计算rotate_position_embedding,每次计算rotate_scale, 将position分为两半,前后两部分分别计算x0 * cos_scale - x1 * sin_scale
并行策略:每次计算每个seq_len下每个head每个rotate下的旋转位置编码,并行线程数seqlen * head * (n_rot / 2)次

llm_matmul_compute_float_float_gpu

矩阵乘积运算,每次只进行一行一列的计算
并行策略:并行线程数M×N

llm_matmul_compute_int4_float_step1_gpu

分两步进行计算,第一步计算两个张量的scale尺度,保存在d中,第二步将对应位置上的x,y取出来,int4乘加后并保存在float中,其中val = x_int4_sum * y_int4_sum * d1 * d2

llm_scale_diag_mask_inf_float_gpu、llm_diag_mask_inf_float_gpu

将矩阵对角线以上的mask设置为无限大,否则乘以scale
并行策略:并行线程数为head * seqlen * (n_past + seqlen)

llm_matmul_compute_with_head_stride_float、llm_head_batched_matmul_compute_float_gpu

用于计算多头、多batch矩阵乘积,并行数量上比普通矩阵乘法多了一个head num

posted @ 2023-08-06 10:32  wildkid1024  阅读(197)  评论(0编辑  收藏  举报