[源码分析] Facebook如何训练超大模型---(4)
[源码分析] Facebook如何训练超大模型 --- (4)
0x00 摘要
我们在前文介绍过,微软 ZeRO 可以对一个万亿参数模型可以使用 8 路模型并行、64 路管道并行和 8 路数据并行在 4,096 个 NVIDIA A100 GPU 上进行扩展。而FSDP(Fully Sharded Data Parallel)是Facebook 深度借鉴微软ZeRO之后提出的PyTorch DDP升级版本,可以认为是对标微软 ZeRO,其本质是 parameter sharding。Parameter sharding 就是把模型参数等切分到各个GPU之上。我们会以 Google,微软和 Facebook 的论文,博客以及代码来进行学习分析。
之前文章之中我们谈到了FSDP支持混合精度训练,所以我们再来看看相关知识。
本系列其他文章如下:
[源码解析] PyTorch 分布式之 ZeroRedundancyOptimizer
[论文翻译] 分布式训练 Parameter sharding 之 ZeRO
[论文翻译] 分布式训练 Parameter Sharding 之 Google Weight Sharding
[源码分析] Facebook如何训练超大模型 --- (2)
[源码分析] Facebook如何训练超大模型 --- (3)
0x01 背景知识
1.1 单精度、双精度和半精度浮点格式的区别
我们从NVIDIA官博 What’s the Difference Between Single-, Double-, Multi- and Mixed-Precision Computing?摘录如下:
IEEE 浮点算术标准是在计算机上用二进制表示数字的通用约定。在双精度格式中,每个数字占用 64 位。单精度格式使用 32 位,而半精度只有 16 位。
在传统的科学记数法中,pi 写为 3.14 x \(10^0\)。但是计算机将这些信息以二进制形式存储为浮点数,即表示数字及其相应指数的一系列 1 和 0,在本例中为 1.1001001 x\(2^1\)。
在单精度 32 位格式中,一位用于判断数字是正数还是负数。为指数保留了八位,指数(因为它是二进制的)是 2 的某个幂。剩余的 23 位用于表示组成数字的数字,称为有效数。
相反,双精度为指数保留 11 位,为有效数保留 52 位,大大扩展了它可以表示的数字的范围和大小。半精度占据了更小的部分,只有 5 个位用于指数,10 个位用于有效数。
以下是 pi 在每个精度级别的样子
1.2 多精度和混合精度计算的区别
多精度计算意味着使用能够以不同精度进行计算的处理器——在需要时使用双精度,并依赖于应用程序的其他部分的半精度或单精度算法。
混合精度,也称为超精度,计算改为在单个操作中使用不同的精度级别,以在不牺牲精度的情况下实现计算效率。在混合精度中,计算从快速矩阵数学的半精度值开始。但是随着数字的计算,机器以更高的精度存储结果。例如,如果将两个 16 位矩阵相乘,则答案大小为 32 位。
使用这种方法,当应用程序完成计算时,累积的答案在准确度上可与在双精度算术中运行整个事情相媲美。这种技术可以将传统双精度应用程序的速度提高多达 25 倍,同时减少运行它们所需的内存、运行时间和功耗。它可用于 AI 和模拟 HPC 工作负载。
1.3 混合精度
采用FP16的优势如下:
- 内存占用更少。如果采用FP16,则模型占用是FP32的一半,这样可以训练更大的模型,使用更大的batch size,通信量更少。
- 计算更快。FP16的加速优化可以加快训练和推理的计算。
- 另外,随着NVIDIA Tensor Core 的普及,FP6计算也越来越快。
FP16的问题主要是其表示范围比FP32狭窄,所以会带来两个问题:溢出错误 和 舍入误差。因此,百度和NVIDIA联手在论文之中提出了一些技术。
- 保留一份FP32格式的权重主备份。
- 使用loss scale来避免梯度过小。
- 使用FP16计算但是用FP32进行累加。
比如,对于主备份,论文之中图例如下:
1.4 训练过程
上面介绍的三种技术对于训练过程是一个良好的补充,我们从NVIDIA官方文档 https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html 摘录训练过程具体如下。
- 维护一份FP32的参数主副本。
- 对于每次迭代:
- 制作权重的FP16副本。
- 使用FP16权重和激活进行向前传播。
- 将得到的损失乘以比例因子S。
- 使用FP16权重,激活和它们的梯度进行后向传播。
- 将权重梯度乘以1/S。
- 完成权重更新(包括gradient clipping等)。
一个更稳健的方法是动态地选择损失比例因子。其基本思想是以一个大的比例因子开始,然后在每次训练迭代中重新考虑它。如果在选定的迭代次数N中没有发生溢出,则增加比例因子。如果发生溢出,则跳过权重更新,降低比例因子。我们发现,只要不频繁地跳过更新,训练计划就不必调整,就可以达到与FP32训练相同的精度。
请注意,N有效地限制了我们可以溢出和跳过更新的频率。缩放因子的更新率可以通过选择增加/减少的乘数以及N(增加前的非溢出迭代次数)来调整。
动态损失缩放方法对应了了以下训练流程:
- 在FP32中保持一份权重的主副本。
- 将S初始化为一个大的数值。
- 对于每个迭代
- 制作一个权重的FP16副本。
- 前向传播(FP16权重和激活)。
- 用比例因子S乘以所得的损失。
- 后向传播(FP16权重、激活和它们的梯度)。
- 如果权重梯度中存在Inf或NaN。
- 减少S。
- 跳过权重更新,进入下一次迭代。
- 将权重梯度与1/S相乘。
- 完成权重更新(包括梯度剪裁等)。
- 如果在过去的N次迭代中没有出现Inf或NaN,则增加S。
图片来自 https://developer.nvidia.com/automatic-mixed-precision
0x02 PyTorch
2.1 英伟达算力
英伟达的Volta及Turing架构GPU在使用FP16计算时的特点如下:
- FP16的内存带宽和存储需求相比FP32来说可以降低一半,这样开发者在相同的硬件条件下可以使用更大更复杂的模型和更大的batch size。
- 英伟达Volta和Turing架构GPU提供了Tensor Cores技术。Tensor Cores的FP16计算吞吐量是FP32的8倍。
因此,在相同的超参数下,使用半精度浮点(FP16)和单精度(FP32)浮点的混合精度训练就可以达到与使用纯单精度(FP32)训练相同的准确率,而且模型训练速度可以大大加速。
2.2 Torch.cuda.amp
PyTorch之中的混合精度主要是依赖 torch.cuda.amp 这个库,这就说明这个功能是依赖于CUDA的。
前面分析提到了为何要混合计算的原因,这是因为:
- 在某些场合下对精度损失不敏感,局部精度损失对最终训练效果影响非常微弱,并且能利用Tensor Cores进行加速,此时FP16有优势。
- 某些场合下对精度损失特别敏感,此时FP32有优势。
PyTorch 之中,与混合精度相关的张量是torch.FloatTensor和torch.HalfTensor,这两个混合起来使用就是混合精度了。而框架会根据实际需要来自动(有时需要手动调整)调整一个张量的类型,在torch.FloatTensor和torch.HalfTensor 之中切换,这就是automatic mixed precision(AMP)的来由。
2.2.1 使用
具体使用上,PyTorch 就是使用了autocast + GradScaler。我们从 https://github.com/NVIDIA/DeepLearningExamples 官方例子找出来看看。
GradScaler 的作用是放大loss,防止梯度underflow,但这只是在反向传播传递梯度时候使用,更新权重时候还需要把梯度缩放回原来的大小。
autocast上下文应该只是包括前向传播和loss计算,因为反向传播会自动使用前向传播同样的类型。
from torch.cuda.amp import autocast as autocast
def do_train(
model,
data_loader,
optimizer,
scheduler,
checkpointer,
device,
checkpoint_period,
arguments,
use_amp,
cfg,
dllogger,
per_iter_end_callback_fn=None,
):
# 模型默认的是torch.FloatTensor
max_iter = len(data_loader)
start_iter = arguments["iteration"]
model.train()
if use_amp:
# 构建GradScaler
scaler = torch.cuda.amp.GradScaler(init_scale=8192.0)
for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
iteration = iteration + 1
images = images.to(device)
targets = [target.to(device) for target in targets]
if use_amp:
with torch.cuda.amp.autocast(): # 前向传播开启autocast
loss_dict = model(images, targets)
else:
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = reduce_loss_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
# Note: If mixed precision is not used, this ends up doing nothing
# Otherwise apply loss scaling for mixed-precision recipe
if use_amp:
scaler.scale(losses).backward() # 放大梯度
else:
losses.backward()
def _take_step():
if use_amp:
scaler.step(optimizer) # 在方法内部,如果梯度正常,则更新权重,否则忽略此次更新
scaler.update() # 是否需要增大scaler
else:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if not cfg.SOLVER.ACCUMULATE_GRAD:
_take_step()
else:
if (iteration + 1) % cfg.SOLVER.ACCUMULATE_STEPS == 0:
for param in model.parameters():
if param.grad is not None:
param.grad.data.div_(cfg.SOLVER.ACCUMULATE_STEPS)
_take_step()
2.2.2 多Model,losses和优化器
如果你的网络有多个损失,你必须在每个网络之中单独调用scaler.scale。如果网络有多个优化器,你可以在它们之中任意一个单独调用scaler.unscale,并且你必须在每个之中都单独调用scaler.step。
但是,在迭代之中所有优化器都完成step操作之后,才可以调用 scaler.update,并且只能调用一次。
每个优化器检查梯度是否为 infs/NaN,并独立决定是否跳过该步骤。这可能会导致一个优化器跳过该步骤,而另一个则没有。由于很少发生跳步(每几百次迭代可能才有一次),这不应妨碍收敛。
scaler = torch.cuda.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
with autocast():
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target)
# (retain_graph here is unrelated to amp, it's present because in this
# example, both backward() calls share some sections of graph.)
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
# You can choose which optimizers receive explicit unscaling, if you
# want to inspect or modify the gradients of the params they own.
scaler.unscale_(optimizer0)
scaler.step(optimizer0)
scaler.step(optimizer1)
scaler.update()
2.2.3 分布式
torch.nn.DataParallel 在每个设备上产生一个线程来运行正向传递。autocast state 是线程本地的,因此以下内容将不起作用:
model = MyModel()
dp_model = nn.DataParallel(model)
# Sets autocast in the main thread
with autocast():
# dp_model's internal threads won't autocast. The main thread's autocast state has no effect.
output = dp_model(input)
# loss_fn still autocasts, but it's too late...
loss = loss_fn(output)
修复很简单。在MyModel.forward
之中使用 autocast。
MyModel(nn.Module):
...
@autocast()
def forward(self, input):
...
# Alternatively
MyModel(nn.Module):
...
def forward(self, input):
with autocast():
...
以下代码在dp_model
的线程(执行forward
)和主线程(执行loss_fn
)中自动转换:
model = MyModel()
dp_model = nn.DataParallel(model)
with autocast():
output = dp_model(input)
loss = loss_fn(output)
torch.nn.parallel.DistributedDataParallel 的文档建议每个进程使用一个 GPU 以获得最佳性能。在这种情况下,DistributedDataParallel
不会在内部产生线程,因此autocast
和GradScaler
的使用不受影响。
或者在 forward 方法内部使用with autocast(),这样可以保证autocast在进程内部生效,比如。
def _forward(self, sample):
loss = None
oom = 0
try:
if sample is not None:
with amp.autocast(enabled=self.args.amp):
# calculate loss and sample size
logits, _ = self.model(**sample['net_input'])
target = sample['target']
probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = self.criterion(probs, target)
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
self.args.distributed_rank), force=True)
oom = 1
loss = None
else:
raise e
return loss, oom
0x03 FSDP 使用
torch.cuda.amp.autocast 与FSDP完全兼容,但是用户需要把mixed_precision
设置为True,具体示例代码如下:
offload_model = OffloadModel(
model=model,
device=torch.device("cuda"),
offload_device=torch.device("cpu"),
num_slices=3,
checkpoint_activation=True,
num_microbatches=1,
)
torch.cuda.set_device(0)
device = torch.device("cuda")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
# To train 1 epoch.
offload_model.train()
for batch_inputs, batch_outputs in dataloader:
batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
start = time.time_ns()
optimizer.zero_grad()
inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)
with torch.cuda.amp.autocast(): # 设定使用 amp
output = model(inputs)
loss = criterion(output, target=batch_outputs)
loss.backward()
optimizer.step()
我们接下来看看FSDP的相关源码,
3.1 成员变量
因为涉及了 CPU offload 和分区等因素,所以FSDP不能简单使用amp,需要和 CPU offload 和分区结合起来看,比如FP16参数也需要分区和offload,因为amp不会自动分区和offload,所以FSDP需要把这部分活承担过来,显式的进行部分切换工作。
前文代码提到了一些与混合精度训练相关的成员变量,这里就是把32,16位参数分别进行分区操作,也会相应做offload操作。
_fp32_shard
:full precision的单个参数分片(通常为fp32,但这取决于用户传入的模型数据类型)。这可以在CPU或GPU上进行,具体取决于cpu_offload
的值。_fp16_shard
:在混合精度模式下,我们在计算设备上维护一个降低精度(通常是FP16)的参数分片,用于在前向/后向传递中执行计算。这就是``_fp16_shard,如果
mixed_precision为
True`,这将是fp16中参数的单个shard,用于all-gather。_full_param_padded
:在向前和向后传播中用于计算的全部权重(被填充为可被world_size
均匀整除)。这将原地调整大小,并仅在需要时具体化(通过all-gather)。
代码之中也需要做相应设置,如果我们计划将FP32/FP16参数保留在CPU上,那么固定内存允许我们以后在将FP32/FP16参数碎片移动到计算设备时使用非阻塞传输。分区操作是 FP32,FP16 统一处理的。
3.2 Scaler
在 Scaler 方法,FSDP也推出了有特色的 ShardedGradScaler。PyTorch自动混合精度的实际使用情况将取决于OSS是与DDP还是与ShardedDDP一起使用。
- 如果OSS与DDP一起使用,那么就可以使用正常的PyTorch GradScaler,不需要做任何改变。
- 如果OSS与ShardedDDP一起使用(为了获得梯度分片),那么可以使用一个非常类似的流程,但它需要一个感知梯度的GradScaler。它可以在
fairscale.optim.grad_scaler
中使用。
在这两种情况下,Autocast都可以照常使用,并且损失将以同样的方式被缩放和处理。
我们看看ShardedGradScaler代码,会发现其特色在于使用 dist.all_reduce 在 ranks 之间进行规约。
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
from torch.optim import Optimizer
from .oss import OSS
class GradScaler(TorchGradScaler):
def _unscale_grads_(
self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
) -> Dict[torch.device, torch.Tensor]:
return super()._unscale_grads_(optimizer, inv_scale, found_inf, True)
class ShardedGradScaler(TorchGradScaler):
"""
A shard-aware :class:`GradScaler<torch.cuda.amp.GradScaler>`, to be used in conjunction with
:class:`OSS` and :class:`ShardedOptimizer`.
Interface and usecases are not changed, more explanations can be found in the corresponding pytorch
documentation https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
"""
def __init__(
self,
init_scale: float = 2.0 ** 16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
enabled: bool = True,
process_group: Any = dist.group.WORLD,
) -> None:
super().__init__(
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
enabled=enabled,
)
self.display_warning = True
self.group = process_group
def unscale_(self, optimizer: Optimizer) -> None:
# Could be a mistake, this scaler is supposed to work with ZeroRedundancyOptimizer only
if self.display_warning and not isinstance(optimizer, OSS):
logging.warning(
"ShardedGradScaler is to be used in combination with a sharded optimizer, this could not be checked"
)
self.display_warning = False # Only warn once
# Call the upstream unscale_ method which will only act on this rank's gradients
super().unscale_(optimizer)
# Synchronize the detected inf across the ranks
optimizer_state = self._per_optimizer_states[id(optimizer)]
last_handle = None
# 使用了 AllReduce
for v in optimizer_state["found_inf_per_device"].values():
last_handle = dist.all_reduce(v, async_op=True, group=self.group)
# Make sure that the calls are done before moving out.
# The calls are executed in sequence, waiting for the last one is enough
if last_handle is not None:
last_handle.wait()
3.3 初始化
我们接着看看 offload 和混合精度如何使用。在初始化方法 _init_param_attributes 之中,也有操作会为移动到CPU做准备,比如放到锁页内存之中,也会为混合精度创建张量,比如_fp16_shard。
@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:
if hasattr(p, "_fp32_shard"):
return
# A single shard of the parameters in full precision.
p._fp32_shard = p.data
if self.mixed_precision:
# 为移动到CPU做准备
if self.move_params_to_cpu:
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard
# 在混合精度模式下,我们在计算设备上维护一个降低精度(通常是FP16)的参数分片,
# 用于在前向/后向传递中执行计算。
# In mixed precision mode, we maintain a reduced precision
# (typically FP16) parameter shard on compute_device for performing
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage_(p._fp16_shard)
else:
p._fp16_shard = None # use _fp32_shard
# We also maintain a full-sized parameter of type self.compute_dtype
# (FP16 for mixed_precision or FP32 otherwise). We resize the
# storage to size 0 at init (here) and only materialize as needed. The
# storage may contain padding elements so that it is evenly divisible by
# world_size, although these padding elements will be removed before the
# relevant computation.
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
)
free_storage_(p._full_param_padded)
# 为移动到CPU做准备
if self.move_grads_to_cpu:
# We can optionally move the grad shard to CPU during the backward
# pass. In this case, it's important to pre-allocate the CPU grad
# shard in pinned memory so that we can do a non-blocking transfer.
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
逻辑如下:
3.4 重建
我们以 _rebuild_full_params 为例。因为前面分析过,这里只是把相关代码摘录,代码会依据各种配置进行切换,比如如果指定了强制全精度,则还需要从FP16转换为FP32,然后再进行all-gather。
@torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
output_tensors: List[Tuple[torch.Tensor, bool]] = []
def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
"""
Helper function to update p.data pointer.
"""
if custom_output_tensor is not None:
# 省略
elif not p._is_sharded:
if self.mixed_precision and not force_full_precision: # 切换到 FP16
p.data = p._fp16_shard
output_tensors.append((p.data, True))
else:
# Here p.data == p._fp32_shard, so it's not safe to free.
output_tensors.append((p.data, False))
else:
# 省略
# Trim any padding and reshape to match original size.
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
with torch.cuda.stream(self._streams["all_gather"]):
if self.mixed_precision and not force_full_precision:
self._cast_fp32_param_shards_to_fp16() # 从fp32切换到fp16
for p in self.params:
if not p._is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# If self.move_params_to_cpu and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
# 拷贝到GPU
p_data = p.data.to(p._full_param_padded.device, non_blocking=True)
p_size = p._full_param_padded.size()
if self.mixed_precision and force_full_precision:
# Allocate fresh tensor in full precision since we are in
# mixed precision and full precision rebuild is asked.
# 在全精度中分配新的张量,因为我们处于混合精度中,需要进行全精度重建。
output_tensor = p_data.new_zeros(p_size)
else:
if p._full_param_padded.storage().size() != p_size.numel():
alloc_storage_(p._full_param_padded, size=p_size)
output_tensor = p._full_param_padded
# Fill output_tensor with (p.data for each shard in self.world_size)
dist.all_gather(chunks, p_data, group=self.process_group) # 简化版本代码
if self.mixed_precision and not force_full_precision:
self._free_fp16_param_shard([p]) # 释放内存
# 省略
逻辑如下:
3.5 cast操作
可以从 _cast_fp32_param_shards_to_fp16 之中看到如何做转换操作。
@torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
"""Cast FP32 param shard to FP16 for a list of params."""
if params is None:
params = self.params
with torch.cuda.stream(self._streams["fp32_to_fp16"]):
for p in params:
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_(
# If cpu_offload is True, this will be non-blocking because
# _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
3.6 _post_reduction_hook
在 _post_backward_hook
之中会设置 callback_fn
,就是在 reduce-scatter
之后调用_post_reduction_hook
,可以理解,就是做完这个操作之后,可以把梯度移动到CPU了。
callback_fn = functools.partial(self._post_reduction_hook, param)
具体代码如下,和offload相关的是把梯度移动到CPU的操作,和混合精度相关的是把梯度转换为参数张量的类型。
def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
"""Hook to call on each param after the reduce-scatter."""
param.grad.data = reduced_grad
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_postdivide_factor)
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision:
orig_param_grad_data = param.grad.data # 把梯度进行转换,一般来说是切换回 FP32
param.grad.data = param.grad.data.to(dtype=param.data.dtype)
# Don't let this memory get reused until after the transfer.
orig_param_grad_data.record_stream(torch.cuda.current_stream())
if hasattr(param, "_saved_grad_shard") and param._saved_grad_shard is not None:
param.grad.data += param._saved_grad_shard
delattr(param, "_saved_grad_shard")
# Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU.
if self.move_grads_to_cpu: # 把梯度移动到CPU
param._cpu_grad.copy_(param.grad.data, non_blocking=True)
# Don't let this memory get reused until after the transfer.
param.grad.data.record_stream(torch.cuda.current_stream())
param.grad.data = param._cpu_grad
至此,混合精度分析完毕,我们下一篇看看 FSDP 如何使用 Activation recomputation,敬请期待。
0xFF 参考
ZeRO-Offload: Democratizing Billion-Scale Model Training
https://www.deepspeed.ai/tutorials/zero-offload/