Pytorch checkpoint
checkpoint一种用时间换空间的策略
torch.utils.checkpoint.
checkpoint
(function, *args, **kwargs)
为模型或模型的一部分设置Checkpoint 。
检查点用计算换内存(节省内存)。 检查点部分并不保存中间激活值,而是在反向传播时重新计算它们。 它可以应用于模型的任何部分。
具体而言,在前向传递中,function将以torch.no_grad()的方式运行,即不存储中间激活值。 相反,前向传递将保存输入元组和function参数。 在反向传播时,检索保存的输入和function参数,然后再次对函数进行正向计算,现在跟踪中间激活值,然后使用这些激活值计算梯度。
(也即,检查点部分在前向计算时不存储中间量,等反向传播需要计算梯度时重新计算这些中间量)
WARNING
- 检查点不适用于torch.autograd.grad(),而仅适用于torch.autograd.backward()。
- 如果反向传播过程中的函数调用与前向传播过程中的函数调用有任何的不同,例如由于某个全局变量,则检查点版本将不相等,并且很遗憾,它无法被检测到。
Parameters
function:
描述模型或模型的一部分在前向传播中运行什么。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过(activation, hidden),则函数应正确使用第一个输入作为activation,第二个输入作为hidden。
reserve_rng_state(bool, optional, default=True)
在每个检查点期间省略存储和恢复RNG状态。
args
包含函数输入的元组(输入)
Returns
在*args(输入)上运行function得到的输出
torch.utils.checkpoint.
checkpoint_sequential
(functions, segments, *inputs, **kwargs)
用于在sequential model中设置检查点的辅助函数。
sequential model按顺序执行模块/函数列表。因此,我们可以将这种模型划分为不同的段,并在每个段上检查点。除最后一个段外的所有段都将以torch.no_grad()方式运行,即不存储中间激活。将保存每个检查点段的输入部分,以便在反向传播中重新运行该段。
See checkpoint()
on how checkpointing works.
Parameters
functions:
A torch.nn.Sequential
或 依次运行的模块或函数(包含模型)的列表。
segments:
在模型中创建的块数
*inputs:
作为函数输入的张量元组
reserve_rng_state(bool, optional, default=True)
在每个检查点期间省略存储和恢复RNG状态。
Returns
在* input上顺序运行函数得到的输出
Example
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
在DenseNet中为了解决GPU内存占用大的问题,就采用了这种策略缓解显存占用大的问题。
下面是denselayer的细节:
1 class _DenseLayer(nn.Sequential): # bottleneck + conv 2 def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 3 super(_DenseLayer, self).__init__() 4 self.add_module("norm1", nn.BatchNorm2d(num_input_features)) 5 self.add_module("relu1", nn.ReLU(inplace=True)) 6 self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate, 7 kernel_size=1, stride=1, bias=False)) 8 9 self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate)) 10 self.add_module("relu2", nn.ReLU(inplace=True)) 11 self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate, 12 kernel_size=3, stride=1, padding=1, bias=False)) 13 14 self.drop_rate = drop_rate 15 self.memory_efficient = memory_efficient 16 17 def forward(self, *prev_features): 18 bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 19 if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 20 bottleneck_output = cp.checkpoint(bn_function, *prev_features) 21 else: 22 bottleneck_output = bn_function(*prev_features) 23 new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 24 if self.drop_rate > 0: 25 new_features = F.dropout(new_features, self.drop_rate, training=self.training) 26 return new_features