0605-优化器
0605-优化器
pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html
一、优化器概述
torch 把深度学习中常用的优化方法都存储在 torch.optim
中,它的设计十分灵活,可以很方便的扩展成自定义的优化方法。
所有的优化方法都继承基类 optim.Optimizer
,并实现了自己的优化步骤,接下来我们将以最基本的优化方法——随机梯度下降法(SGD)距离说明。在这里需要重点掌握以下三个方法:
- 优化方法的基本使用方法
- 如何对模型的不同部分设置不同的学习率(lr)
- 如何调整学习率
import torch as t
from torch import nn
from torch.autograd import Variable as V
# 首先定义一个 LeNet 网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.features = nn.Sequential(nn.Conv2d(3, 6, 5), nn.ReLU(),
nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5),
nn.ReLU(), nn.MaxPool2d(2, 2))
self.classifier = nn.Sequential(nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
nn.Linear(120, 84), nn.ReLU(),
nn.Linear(84, 10))
def forward(self, x):
x = self.features(x)
x = x.view(-1, 16 * 5 * 5)
x = self.classifier(x)
return x
net = Net()
from torch import optim
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # 梯度清零,等价于 net.zero_grad()
inp = V(t.randn(1, 3, 32, 32))
output = net(inp)
output.backward(output)
optimizer.step() # 优化参数
二、针对不同的网络设定不同的 lr
# 为不同子网络设置不同的学习率,在 finetune 中经常用到
# 如果对某个参数不指定学习率,就使用默认学习率
ptimizer = optim.SGD(
[
{
'params': net.features.parameters()
}, # 学习率为 1e-5
{
'params': net.classifier.parameters(),
'lr': 1e-2
}
],
lr=1e-5)
三、针对不同的层设定不同的 lr
# 只为两个全连接层设置较大的学习率,其余层的学习率较小
special_layers = nn.ModuleList([net.classifier[0], net.classifier[3]])
special_layers_params = list(map(id, special_layers.parameters())) # 得到特殊层的 id
# 筛选出不属于特殊层的层
base_params = filter(lambda p: id(p) not in special_layers_params,
net.parameters())
# 对于特殊层和非特殊层设定不同的 lr
optimizer = t.optim.SGD([{
'params': base_params
}, {
'params': special_layers.parameters(),
'lr': 0.01
}],
lr=0.001)
四、动态修改 lr
在跑代码的过程中,我们可能需要中途改变学习率的大小。在 torch 中提供了两种做法:
- 直接修改 optimizer.parm_groups 中对应的学习率(不推荐)
- 由于 optimizer 十分轻量级,开销很小,因此可以新建优化器(推荐)
如果使用第二种方法新建一个优化器,在这个过程中新建的优化器会初始化动量等状态信息,这对使用动量的优化器来说(如自带 momentum 的 sgd),可能会造成损失函数在收敛过程中震荡。
# 调整学习率,新建一个 optimizer
old_lr = 0.1
optimizer = optim.SGD([{
'params': net.features.parameters()
}, {
'params': net.classifier.parameters(),
'lr': old_lr * 0.1
}],
lr=1e-5)