backward的理解

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


#反向传播
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y
out = z.mean()
# 如果没有下面这一行,x.grad=none
out.backward()
print("x.grad:{}\n".format(x.grad))

对于上面的例子,参数是x0,x1,x2,x3,假设我们把参数全部初始化为2,并且得到样本为2,2,2,2。
则本次迭代梯度计算的形式是\(out=\frac{1}{4} \sum_{i} z_{i}, z_{i}=\left(x_{i}+2\right)^{2}\)

根据链式求导法则可以求得out关于x0,x1,x2,x3的偏导数解析形式。
而backward()则是将样本值带入带入偏导数解析形式,求出数值上的偏导数。

posted on 2021-06-27 10:38  A2he  阅读(167)  评论(0编辑  收藏  举报