LLM并行训练3-数据并行
前置知识#
混合精度训练#
#
在参数存储时采取fp32, 开始进行fp/bp时转成fp16运算, 拿到fp16梯度后再转回fp32更新参数.
ZeRO对显存占用的估算:
- 模型状态: Weights(fp16)、grad(fp16) 和 MasterWeights(fp32 模型参数备份),momentum(fp32)和variance(fp32)。假设模型参数量
,则共需要 字节存储, - 剩余状态: 除了模型状态之外的显存占用,包括激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)
Adam#

在adam optimizer的计算状态除了参数, 还有一个
AllToAll通信原语#

allToall类似于矩阵转置. 相当于我们需要先把每个节点里的数据按照他们要传递给哪个节点排好序, 然后根据切分好的顺序推给对应的节点. 可以看到如果每个节点的数据量是
ZeRO#
在传统的训练方法里, 每张卡里存储一份完整的模型状态, 完成bp后allReduce grad,再更新每张卡里的副本. 这样子有N张卡就会多出
- ZeRO1: 只保留一份
MasterWeights+momentum+variance
. - ZeRO2: 在ZeRO1的基础上去除了grad的冗余
- ZeRO3: 在ZeRO2的基础上去掉了weights的冗余

训练流程#
以ZeRO3为例. 主要分为5步, 假设使用了4张卡进行训练:
- 每张卡上存1/4的W, OS和grad. 每张卡训练自己分配到的batch.
- fp时, AllGather所有卡上的W,取到全量的W(fp16)进行fp, 完成后只保留自己需要维护的1/4 W, 其他显存释放回池
- bp时, AllGather所有卡上的W进行bp, 完成后再抛弃其他卡维护的W
- 完成bp后, ReduceScatter所有卡的G, 从其他卡上取到需要需要更新的梯度增量, 然后释放不是自己维护的G.
- 使用自己维护的OS和G来更新W, 不需要通信.


通信量分析#
定义单卡数据量为
传统DP: bp完成后需要对梯度进行一次AllReduce, 一共
ZeRO1: 只舍弃了OS, bp时需要AllReduce G(Scatter+Gather 共
ZeRO2: 舍弃了OS和G, bp时AllGather G(
ZeRO3: 上面训练过程分析过, 共需要2次Gather和1次Scatter, 一共需要
可以看到ZeRO在通信量只增加了1.5倍的情况下, 显存降了60倍. 效果非常显著
ZeRO++#
ZeRO存在的问题是会在GPU之间产生大量数据传输开销,降低了训练效率. 主要有两种情况:
-
全局batch size较小,而 GPU数量多,这导致每个 GPU 上batch size较小,需要频繁通信
-
在低端集群上进行训练,其中跨节点网络带宽有限,导致高通信延迟。
ZeRO++主要采用了3部分优化: 权重量化 (qwZ), 分层分割存储 (hpZ), 梯度量化 (qgZ). 对比ZeRO通信量减少了4倍, 主要的难点都在减小量化带来的训练误差上
权重量化#
def _quantize_int8(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
q_range = 2**self.config['num_bits'] - 1
min_value = tensor.amin(dim=self.config['group_dim'] + 1, keepdim=True)
max_value = tensor.amax(dim=self.config['group_dim'] + 1, keepdim=True)
scale = q_range / (max_value - min_value)
tensor = tensor.sub_(min_value).mul_(scale)
tensor = tensor_round(tensor_clamp(tensor, 0, q_range)).to(torch.uint8) #对称式量化
return tensor, scale, min_value
量化代码在deepspeedcsrc/quantization/quantize.cu cached_quantization
这个kernel里.
如果采用全局fp16->int8的量化会导致极大误差. deepspeed采用了分区量化的方法, 把参数分为固定大小的block后, 先根据这个block的max/min计算出scale(量化系数), 在把这个参数传入量化函数里. 另外在通信的时候应该也需要每个block对应的系数传给接收节点用于反量化.
通过这种方式在通信量减半的同时还能保证精度, 很nice的思路.
分层分割存储#

之前ZeRO的W切分方法是根据卡数均分. 在fp/bp之前进行AllGather拉取, 后来发现在机器间进行Gather通信是比较严重的瓶颈. 所以最后W的切分变成了每个节点内存储全量的W, 节点内根据卡数进行切片. 避免跨节点经过网卡的通信, 通过增加显存使用的方式解决通信瓶颈.
显存消耗: ZeRO3的单卡显存消耗为
梯度量化#
如果直接在之前zero RingAllReduce的通信方式上加量化和反量化, 如下图左, 可以看到需要节点个数次量化/反量化. 而每次量化都是有损的, 这样会导致无法接受的训练误差. 为了解决这个问题zero++使用了一次量化->AllToAll通信->一次反量化的操作. 而因为直接进行AllToAll通信量从

具体图示流程可以参考Deepspeed的blog动态图, 文字版步骤:
-
节点内的卡间张量切片重排. 主要是因为alltoall切分成了两步, 如果不重排如下图左. 最后顺序会变错位, 然后进行参数量化
-
节点内alltoall通信后反量化.先把卡内能合并的梯度加起来. 这里反量化主要是为了减小梯度累加的精度损失
-
再次量化后, 节点间进行allToAll
-
拿到通信结果, 反量化后再次reduce. 得到最终的梯度.
这里要进行两次alltoall的原因主要是, 第一次卡间alltoall之后梯度累加可以减少卡数倍的通信规模. 实际deepspeed在实现的时候还把重分片和量化kernel进行了fuse, 进一步优化性能
还有下图的方法, 在通信当前层的时候, 通过多流异步量化下一层要通信的数据. 避免同步等待的浪费

参考#
zero: https://arxiv.org/pdf/1910.02054
混合精度训练: https://arxiv.org/pdf/1710.03740
zero++: https://arxiv.org/abs/2306.10209
Deepspeed blog: https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md
作者:sunstrikes
出处:https://www.cnblogs.com/sunstrikes/p/18274445
版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单!
· 推荐几款开源且免费的 .NET MAUI 组件库
· 实操Deepseek接入个人知识库
· 易语言 —— 开山篇
· 【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体