ZeRO-DP: 零冗余并行训练

ZeRO-DP: 零冗余并行训练

论文地址:ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

自transformer及后续bert、gpt等模型发布,神经网络模型变得越来越大、训练用数据量越来越大,模型难以使用单卡进行训练,因而并行训练技术也越来越受到。在神经网络训练中,显存是制约模型训练的关键元素之一,如果gpu显存足够大,那么我们根本就不需要进行并行训练,在一张卡上完成训练即可。然而现实情况是,目前最大显存大约在80GB左右,以FP16半精度能部署约40B大小的模型,更别提训练了。而且如何减少训练时占用的显存,是一件十分值得研究的事情。

注:本文中仅介绍了ZeRO-DP对于训练时显存的优化,对于ZeRO-R部分关于激活值、缓冲区、显存碎片管理等的优化暂不涉及。

显存花到了哪里?

要研究如何减少显存占用,需要先了解显存到底花在了哪里。首先是模型参数,假设模型参数量是Φ,模型以FP16精度进行训练,则模型参数占用显存为2Φ。之后进行训练时需要保存梯度,梯度占用显存为2Φ。接下来是优化器状态,以常用的AdamW为例,需要以FP32精度保存模型参数、一阶动量、二阶动量共(4+4+4)Φ=12Φ。故在一般训练时需要显存量为2Φ+2Φ+12Φ=16Φ

哪些可以省去?

Pos: Optimizer State Partitioning

首先问题个问题:每张卡都需要保存所有的优化器状态吗?答案是:不需要。

假设我们有N张卡进行训练,每张卡分得1N份优化器状态,那么模型训练所需显存为2Φ+2Φ+12ΦN=(4+12N)Φ4Φ

由于每张卡只保留了1/N的优化器状态,所以,每张卡只能优化1NΦ的参数。在每次前向传播与反向传播完成时,首先经过一次AllReduce操作,获取全局梯度。然后进行参数更新。参数更新后,经过一次all-gather操作,将更新后的参数分发到每张卡上,完成一轮训练。整个过程单卡通信量约为3Φ

Pg: Gradient Partitioning

同样的问题:每张卡都需要保存所有的梯度吗?答案是:不需要。

假设我们有N张卡进行训练,每张卡分得1N份优化器状态和1N份梯度,那么模型训练所需显存为2Φ+14NΦ2Φ

在进行参数更新时,首先经过一次reduce-scatter操作,获取每张卡对应梯度的聚合状态。然后进行对应参数更新。参数更新完后,经过一次all-gather操作,将更新后的参数分发到每张卡上,完成一轮训练。整个过程单卡通信量约为2Φ

Pp: Parameter Partitioning

最后的问题:每张卡都需要保存所有的参数吗?答案是:不需要。

假设我们有N张卡进行训练,每张卡分得1N份优化器状态、梯度以及模型参数,那么模型训练所需显存为16NΦ

在进行前向传播时,首先进行一次All-gather操作,取回分布在别的卡上的参数,完成前向计算后,立即把不属于自己的参数抛弃,通信量Φ。在进行反向传播时,首先经过一次All-gather操作,通信量Φ,计算完梯度后,立即把参数抛弃。之后进行一次reduce-scatter操作,通信量Φ,获取每张卡对应梯度的聚合状态。然后进行参数更新,由于每张卡只维持了1NΦ的参数,所以无需在进行通信传输更新后的参数。整个过程单卡通信量约为3Φ

总结

显存 单卡通信量
朴素DP (2+2+12)Φ=16Φ 2Φ
Pos (2+2+12N)Φ4Φ 3Φ
Pos+g (2+14N)Φ2Φ 2Φ
Pos+g+p 16N 3Φ
posted @   ywycs0201  阅读(18)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现
点击右上角即可分享
微信分享提示