LLM并行训练2-张量并行
切分方式
前置知识
以下定义X的dim为(M,K), W的dim为(K, N), 平均切分z次
行式切分
forward
先把X按列切分每个子块的dim都是 (M, K/z), W1的dim(K/z, N), 这里利用了分块矩阵乘法的性质, 把切分好的Xi scatter到对应W的卡上, 计算完成后相加结果矩阵即可拿到Y的前向结果
backward:
Y对Yi的偏导因为 Y= Y1 + Y2求导偏导是1, 可以直接省略. 只需要把L对Y的偏导广播到W1, W2各自的卡上, 他们就能各自计算对应的梯度来更新W. L对X的偏导也是两张卡各自计算后(L对Y的偏导 * Wi的转置), 最后按列concat到一起就能得到最终X的偏导
列式切分
forward:
因为按列切分没有改变矩阵乘法的中间dim, 前向只需要concat起来两个切分后的乘法结果
backward:
这里是需要先把L对Y的导数切分后再传给各张卡, L对W的偏导计算方法和行切分一样, L对X的偏导因为对于损失L,X既参与了XW1的计算,也参与了XW2的计算, 所以需要把两张卡上对X1,X2的偏导求和. 得到最终的结果
MLP并行
以Y = GELU(X * A) * B 为例
forward: 把参数A进行列切分, B进行行切分. 先把X广播到每张卡上, 每张卡直接算完从A->B的所有流程后, AllReduce计算结果就能得到Y
Backward: 把Grad(y)广播到各张卡上独立反向, 然后allreduce所有的grad(xi), 就能得到grad(x)
这个设计真挺巧妙的. 如果我们只用行切分或者列切分, 在两个矩阵计算的中间必然会进行一次集合通信的同步. 列切分是AllGather, 行切分是AllReduce. 然而先行后列, 中间除了节省掉集合通信的成本, 连第二次列切分的时候需要先对X做分块操作的步骤都节省了. 牛啊
MultiHeadAttention并行
如果有两个头两张卡, 把V,Q,K权重矩阵进行列切分后. 算出来的Q1,Q2 通过concat就能得到Q, 完美的切分了数据和算力..真的感觉天然适配张量并行, 只要我们保证head数能整除卡数就能完全利用起来所有的卡.
总结
张量并行结合了分块矩阵运算的性质, 通过合理的切分输入和参数, 再加上行列切分的合理配置. 就能节省掉很多过程中的不必要通信和冗余计算. 而且对效果无损, 看的过程中感觉好神奇.