pytorch---损失函数和优化器
一、损失函数
损失函数可以当作是nn的某一个特殊层,也是nn.Module的子类。但是实际中。然而在实际使用中通常将这些loss function专门提取出来,和主模型互相独立。
score=t.randn(3,2) //batch_size=3,类别是2.
label=t.Tensor([1,0,1].long()) //注意label必须得是longTensor
criterion=nn.CrossEntropyLoss() //CroosEntropyLoss是常用的二分类损失函数
loss=criterion(score,label)
二、优化器
所有的优化方法都封装在torch.optim里面,他的设计很灵活,可以扩展为自定义的优化方法。
所有的优化方法都是继承了基类optim.Optimizer。并实现了自己的优化步骤。
关于优化器需要掌握:
1、优化方法基本使用方法。
2、如何为不同的层设置不同的学习率。
3、如何调整学习率。
1、优化方法基本使用
import optim
ptimizer=optim.SGD(params=net.parameters(),lr=0.1)
optimizer.zero_grad() //梯度清零
output=net(input)
output.backward(output)
optimizer.step()
2、如何为不同的层设置不同的学习率
法1
# 为不同子网络设置不同的学习率,在finetune中经常用到
# 如果对某个参数不指定学习率,就使用最外层的默认学习率
optimizer =optim.SGD([
{'params': net.features.parameters()}, # 学习率为1e-5
{'params': net.classifier.parameters(), 'lr': 1e-2}
], lr=1e-5)
法2
# 只为两个全连接层设置较大的学习率,其余层的学习率较小
special_layers = nn.ModuleList([net.classifier[0], net.classifier[3]])
special_layers_params = list(map(id, special_layers.parameters()))
base_params = filter(lambda p: id(p) not in special_layers_params,
net.parameters())
optimizer = t.optim.SGD([
{'params': base_params},
{'params': special_layers.parameters(), 'lr': 0.01}
], lr=0.001 )
optimizer
3、如何调整学习率
法1(比较推荐):新建一个optimizer。对于使用动量的优化器(如Adam),会丢失动量等状态信息,可能会造成损失函数的收敛出现震荡等情况
old_lr = 0.1
optimizer1 =optim.SGD([
{'params': net.features.parameters()},
{'params': net.classifier.parameters(), 'lr': old_lr*0.1}
], lr=1e-5)
optimizer1
法2:一种是修改optimizer.param_groups中对应的学习率
# 方法2: 调整学习率, 手动decay, 保存动量
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1 # 学习率为之前的0.1倍
optimizer
nn中的每一个layer在functional中基本都有与之对应的函数。为此怎么判断何时用nn.Module,何时用functional里的函数呢。
input = t.randn(2, 3)
model = nn.Linear(3, 4)
output1 = model(input)
output2 = nn.functional.linear(input, model.weight, model.bias)
output1 == output2
b = nn.functional.relu(input)
b2 = nn.ReLU()(input)
b == b2
可见,Module是一个类,会自动提取可学习的参数。 但是functional是一个纯函数。所以当某layer是有看学习参数的话用Module实现(卷积、全连接),否则就是functional(激活、池化)。但是,注意Dropout虽然可学习参数但是因为在训练和测试的时候他的行为不一样,所以仍然放在Module,通使用nn.Module
对象能够通过model.eval
操作加以区分。