10.损失函数及其梯度(均方差梯度[使用线性函数y=w*x+b作为激活函数])

1.MSE(均方差)梯度

(1)均方差MSE

(2)MSE求梯度

 

 

 

 【注】例如网络形式为线性感知机:ƒ(x)=w*x+b这里只是举例,具体用什么样的函数需要根据实际的网络结构。

对w求导则是:Δƒw(w)/Δw

对b求导则是:Δƒb(b)/Δb

(3)均方差在pytorch中如何求梯度

(3.1.1)torch.autograd.grad(loss,[w1,w2...........])

 

 

 【注】pytorch中mse_loss的自动微分:

F.mse_loss(label,pred) pred的为线性感知机中的w*x+b,label为x。

torch.autograd.grad(mse,para)para为线性感知机中的w和b参数。其中第一个参数必须为维度为1长度为1的tensor。

【注】只有浮点数型数据才能计算梯度,故上图中会出现23和24行下面的错误。requires_grad_()可以对tensor类型的数据进行更新,使其可以进行梯度运算。

(3.1.2)loss.backward()

 

 (3.1.3)pytorch中损失函数求梯度的两种方法总结

 

 [注]两种方式返回值的形式不同:

第一种为【w1 grad,w2 grad】

第二种为w1.grad或者w2.grad等。

[注]可以对tensor类型的数据进行.norm查看tensor的norm,也可以对梯度信息进行.norm。

posted @ 2021-07-23 19:30  收购阿里巴巴  阅读(796)  评论(0编辑  收藏  举报