Python——torch 自动求导

关于自动求导的理解

注意:

调用 backward() 的必须是标量,

若非标量,可先点乘同结构的单位张量,相当于执行一次sum(),从而变成标量。

backward.py 

import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,2),requires_grad=True)
y = x + 2
print("x =", x)
print("y =", y)

z = y * y * 3
print("z =", z)
# o = z.sum()
o = z.mean()
print("o =", o)
print("o.grad_fn =", o.grad_fn)

o.backward()
print("x.grad =", x.grad) 

以上代码对应的数学计算过程:

 

 

 

执行   python3 backward.py 

输出为

x = tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
y = tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
z = tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)
o = tensor(27., grad_fn=<MeanBackward0>)
o.grad_fn = <MeanBackward0 object at 0x106128fd0>
x.grad = tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

 

posted @ 2022-02-16 19:11  会飞的斧头  阅读(400)  评论(0编辑  收藏  举报