梯度削减方法
1.torch.nn.utils.clip_grad_norm_
https://pytorch.org/docs/master/generated/torch.nn.utils.clip_grad_norm_.html
根据梯度的范数值进行削减,
https://stackoverflow.com/questions/54716377/how-to-do-gradient-clipping-in-pytorch
optimizer.zero_grad() loss, hidden = model(data, hidden, targets) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) optimizer.step()
2.clip_grad_value_
https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py,这两个的代码实现都在这个文件下。
如果是根据范数削减,就会先求范数:
如果是根据值削减,那就简单了:
直接对存在梯度值的参数进行削减。使用方法:
nn.utils.clip_grad_value_(net.linear.weight, clip_value=1.1)
3.例子
https://zhuanlan.zhihu.com/p/99953668,这个非常好。
给出的梯度=120的时候,就会导致梯度爆炸,后续更新参数就会出现问题。
下面这个实验是想看一下net.named_parameters()和net.parameters()访问所有参数有什么区别
import torch import torch.nn as nn class LinearNet(nn.Module): def __init__(self, features_in=1, features_out=1): super().__init__() self.linear = nn.Linear(features_in, features_in) self.linear2 = nn.Linear(features_in, features_out) self._init_weight() def forward(self, x): return self.linear(x) def _init_weight(self): nn.init.constant_(self.linear.weight, val=1) nn.init.constant_(self.linear.bias, val=0) net = LinearNet() for tag, value in net.named_parameters(): print(tag) print(value) print('\n') for p in net.parameters(): print(p) print(p.grad)
输出:
linear.weight #参数名 Parameter containing: #参数的值 tensor([[1.]], requires_grad=True) linear.bias Parameter containing: tensor([0.], requires_grad=True) linear2.weight Parameter containing: tensor([[-0.4493]], requires_grad=True) linear2.bias Parameter containing: tensor([0.9038], requires_grad=True) Parameter containing:#直接就是只有值 tensor([[1.]], requires_grad=True) None#通过.grad就可以访问梯度,但由于现在还没求,所以是None Parameter containing: tensor([0.], requires_grad=True) None Parameter containing: tensor([[-0.4493]], requires_grad=True) None Parameter containing: tensor([0.9038], requires_grad=True) None
4.爆炸/消失原因
知乎 https://zhuanlan.zhihu.com/p/48776056
梯度爆炸:当梯度g>1且计算次数n比较大时, 有可能会变得非常大;
梯度消失:当g<1且计算次数n比较大时,有可能非常小,基本相当于为0。
5.过程
先梯度截断,再用优化器更新梯度参数。