LLM并行训练2-张量并行
切分方式#
前置知识#
以下定义X的dim为(M,K), W的dim为(K, N), 平均切分z次
行式切分#

forward
先把X按列切分每个子块的dim都是 (M, K/z), W1的dim(K/z, N), 这里利用了分块矩阵乘法的性质, 把切分好的Xi scatter到对应W的卡上, 计算完成后相加结果矩阵即可拿到Y的前向结果
backward:
Y对Yi的偏导因为 Y= Y1 + Y2求导偏导是1, 可以直接省略. 只需要把L对Y的偏导广播到W1, W2各自的卡上, 他们就能各自计算对应的梯度来更新W. L对X的偏导也是两张卡各自计算后(L对Y的偏导 * Wi的转置), 最后按列concat到一起就能得到最终X的偏导
列式切分#

forward:
因为按列切分没有改变矩阵乘法的中间dim, 前向只需要concat起来两个切分后的乘法结果
backward:
这里是需要先把L对Y的导数切分后再传给各张卡, L对W的偏导计算方法和行切分一样, L对X的偏导因为对于损失L,X既参与了XW1的计算,也参与了XW2的计算, 所以需要把两张卡上对X1,X2的偏导求和. 得到最终的结果
MLP并行#
以Y = GELU(X * A) * B 为例

forward: 把参数A进行列切分, B进行行切分. 先把X广播到每张卡上, 每张卡直接算完从A->B的所有流程后, AllReduce计算结果就能得到Y
Backward: 把Grad(y)广播到各张卡上独立反向, 然后allreduce所有的grad(xi), 就能得到grad(x)
这个设计真挺巧妙的. 如果我们只用行切分或者列切分, 在两个矩阵计算的中间必然会进行一次集合通信的同步. 列切分是AllGather, 行切分是AllReduce. 然而先行后列, 中间除了节省掉集合通信的成本, 连第二次列切分的时候需要先对X做分块操作的步骤都节省了. 牛啊
MultiHeadAttention并行#

如果有两个头两张卡, 把V,Q,K权重矩阵进行列切分后. 算出来的Q1,Q2 通过concat就能得到Q, 完美的切分了数据和算力..真的感觉天然适配张量并行, 只要我们保证head数能整除卡数就能完全利用起来所有的卡.
总结#
张量并行结合了分块矩阵运算的性质, 通过合理的切分输入和参数, 再加上行列切分的合理配置. 就能节省掉很多过程中的不必要通信和冗余计算. 而且对效果无损, 看的过程中感觉好神奇.
参考#
作者:sunstrikes
出处:https://www.cnblogs.com/sunstrikes/p/18271719
版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单!
· 推荐几款开源且免费的 .NET MAUI 组件库
· 实操Deepseek接入个人知识库
· 易语言 —— 开山篇
· 【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体