LLM并行训练2-张量并行

切分方式#

前置知识#

矩阵乘法求导

Y=f(AB)=f(C)

YA=YCBT

YB=ATYC

以下定义X的dim为(M,K), W的dim为(K, N), 平均切分z次

行式切分#

image-20240627150212684

forward

Y=X1W1+X2W2

X=concat(X1,X2,axis=1)

W=concat(W1,W2,axis=0)

先把X按列切分每个子块的dim都是 (M, K/z), W1的dim(K/z, N), 这里利用了分块矩阵乘法的性质, 把切分好的Xi scatter到对应W的卡上, 计算完成后相加结果矩阵即可拿到Y的前向结果

backward:

LWi=LYYYiYiWi

Y对Yi的偏导因为 Y= Y1 + Y2求导偏导是1, 可以直接省略. 只需要把L对Y的偏导广播到W1, W2各自的卡上, 他们就能各自计算对应的梯度来更新W. L对X的偏导也是两张卡各自计算后(L对Y的偏导 * Wi的转置), 最后按列concat到一起就能得到最终X的偏导

列式切分#

image-20240627151744984

forward:

Y=concat(X1W1,X2W2,axis=1)

因为按列切分没有改变矩阵乘法的中间dim, 前向只需要concat起来两个切分后的乘法结果

backward:

LWi=LYYiWi

LX=LX1+YiX2

这里是需要先把L对Y的导数切分后再传给各张卡, L对W的偏导计算方法和行切分一样, L对X的偏导因为对于损失L,X既参与了XW1的计算,也参与了XW2的计算, 所以需要把两张卡上对X1,X2的偏导求和. 得到最终的结果

MLP并行#

以Y = GELU(X * A) * B 为例

image-20240627165655511

forward: 把参数A进行列切分, B进行行切分. 先把X广播到每张卡上, 每张卡直接算完从A->B的所有流程后, AllReduce计算结果就能得到Y

Backward: 把Grad(y)广播到各张卡上独立反向, 然后allreduce所有的grad(xi), 就能得到grad(x)

这个设计真挺巧妙的. 如果我们只用行切分或者列切分, 在两个矩阵计算的中间必然会进行一次集合通信的同步. 列切分是AllGather, 行切分是AllReduce. 然而先行后列, 中间除了节省掉集合通信的成本, 连第二次列切分的时候需要先对X做分块操作的步骤都节省了. 牛啊

MultiHeadAttention并行#

image-20240627170918832

如果有两个头两张卡, 把V,Q,K权重矩阵进行列切分后. 算出来的Q1,Q2 通过concat就能得到Q, 完美的切分了数据和算力..真的感觉天然适配张量并行, 只要我们保证head数能整除卡数就能完全利用起来所有的卡.

总结#

张量并行结合了分块矩阵运算的性质, 通过合理的切分输入和参数, 再加上行列切分的合理配置. 就能节省掉很多过程中的不必要通信和冗余计算. 而且对效果无损, 看的过程中感觉好神奇.

参考#

https://zhuanlan.zhihu.com/p/622212228

作者:sunstrikes

出处:https://www.cnblogs.com/sunstrikes/p/18271719

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   SunStriKE  阅读(362)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单!
· 推荐几款开源且免费的 .NET MAUI 组件库
· 实操Deepseek接入个人知识库
· 易语言 —— 开山篇
· 【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体
点击右上角即可分享
微信分享提示
more_horiz
keyboard_arrow_up light_mode palette
选择主题
menu