【PyTorch官方教程中文版学习笔记03】损失函数&更新权重

1. 损失函数

    在深度学习中,损失反映模型最后预测结果与实际真值之间的差距,可以用来分析训练过程的好坏、模型是否收敛等,例如均方损失、交叉熵损失等。

    PyTorch中,损失函数可以看做是网络的某一层而放到模型定义中,但在实际使用时更偏向于作为功能函数而放到前向传播过程中。

    损失函数举例:均方误差(mean squared error)、交叉熵误差(cross entropy error)等使用时参考手册torch.nn — PyTorch 1.11.0 documentation

  实例:nn.MSELoss  均方误差

    

复制代码
#接上篇神经网络
out = net(input)
net.zero_grad()
out.backward(torch.randn(1, 10))
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()
loss = criterion(out, target)
print(loss)
#输出结果
tensor(1.7306, grad_fn=<MseLossBackward0>)
“因为input和target的是随机torch阵,所以loss结果不固定”
复制代码

 

2. 反向传播

     在神经网络的学习中,寻找最优参数(权重和偏置)时,要寻找使损失函数的值尽可能小的参数。

     为了找到使损失函数的值尽可能小的地方,需要计算参数的导数(确切地讲是梯度),然后以这个导数为指引,逐步更新参数的值。

     数值微分可以计算神经网络的权重参数的梯度(严格来说,是损失函数关于权重参数的梯度),但是计算上比较费时间。

     误差反向传播法则是一个能够高效计算权重参数的梯度的方法。

 

  •  为了实现反向传播损失,我们所有需要做的事情仅仅是使用 loss.backward()。你需要清空现存的梯度,要不然梯度将会和现存的梯度累计到一起。
net.zero_grad() # zeroes the gradient buffers of all parameters
print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)
loss.backward()
print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)
#输出
conv1.bias.grad before backward None #清空了现存梯度 conv1.bias.grad after backward tensor([
8.4274e-05, 2.4798e-03, 1.1413e-03, 2.4606e-03, 1.6488e-02, -7.1301e-03])
  • 优化器

    利用反向传播,优化器应运而生。优化器可以更新参数即网络中的权重,进行模型优化、加速收敛。

    常用的优化器算法SGD, Nesterov-SGD, Adam,RMSProp等。优化算法的设计可以作为课题深层研究,目前只需要会使用现成算法就可以。

    算法在算法包torch.optim。

    torch.optim — PyTorch 1.11.0 documentation

optimizer = optim.SGD(net.parameters(), lr=0.01)# create your optimizer
optimizer.zero_grad()# zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()# Does the update

    优化通常要经过好几轮的for循环,训练模型使得模型整体loss减小。

 

    

至此,我们在前篇
1.定义一个神经网络
2.处理输入以及调用反向传播

的基础上继续补充了:

3.计算损失值
4.更新网络中的权重

 完成了一个典型的神经网络训练过程。

 

 参考文献:

《Pytorch官方教程中文版 》

《 深度学习之Pytorch物体检测实战》

《深度学习入门:基于python的理论与实现》

posted @   只想毕业的菜狗  阅读(858)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示