Python——torch 自动求导

关于自动求导的理解

注意:

调用 backward() 的必须是标量,

若非标量,可先点乘同结构的单位张量,相当于执行一次sum(),从而变成标量。

backward.py 

复制代码
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,2),requires_grad=True)
y = x + 2
print("x =", x)
print("y =", y)

z = y * y * 3
print("z =", z)
# o = z.sum()
o = z.mean()
print("o =", o)
print("o.grad_fn =", o.grad_fn)

o.backward()
print("x.grad =", x.grad) 
复制代码

以上代码对应的数学计算过程:

 

 

 

执行   python3 backward.py 

输出为

复制代码
x = tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
y = tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
z = tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)
o = tensor(27., grad_fn=<MeanBackward0>)
o.grad_fn = <MeanBackward0 object at 0x106128fd0>
x.grad = tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])
复制代码

 

posted @   会飞的斧头  阅读(414)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示