Python——torch 自动求导
关于自动求导的理解
注意:
调用 backward() 的必须是标量,
若非标量,可先点乘同结构的单位张量,相当于执行一次sum(),从而变成标量。
backward.py
import torch from torch.autograd import Variable x = Variable(torch.ones(2,2),requires_grad=True) y = x + 2 print("x =", x) print("y =", y) z = y * y * 3 print("z =", z) # o = z.sum() o = z.mean() print("o =", o) print("o.grad_fn =", o.grad_fn) o.backward() print("x.grad =", x.grad)
以上代码对应的数学计算过程:
执行 python3 backward.py
输出为
x = tensor([[1., 1.], [1., 1.]], requires_grad=True) y = tensor([[3., 3.], [3., 3.]], grad_fn=<AddBackward0>) z = tensor([[27., 27.], [27., 27.]], grad_fn=<MulBackward0>) o = tensor(27., grad_fn=<MeanBackward0>) o.grad_fn = <MeanBackward0 object at 0x106128fd0> x.grad = tensor([[4.5000, 4.5000], [4.5000, 4.5000]])