Python——torch 自动求导
关于自动求导的理解
注意:
调用 backward() 的必须是标量,
若非标量,可先点乘同结构的单位张量,相当于执行一次sum(),从而变成标量。
backward.py
import torch from torch.autograd import Variable x = Variable(torch.ones(2,2),requires_grad=True) y = x + 2 print("x =", x) print("y =", y) z = y * y * 3 print("z =", z) # o = z.sum() o = z.mean() print("o =", o) print("o.grad_fn =", o.grad_fn) o.backward() print("x.grad =", x.grad)
以上代码对应的数学计算过程:
执行 python3 backward.py
输出为
x = tensor([[1., 1.], [1., 1.]], requires_grad=True) y = tensor([[3., 3.], [3., 3.]], grad_fn=<AddBackward0>) z = tensor([[27., 27.], [27., 27.]], grad_fn=<MulBackward0>) o = tensor(27., grad_fn=<MeanBackward0>) o.grad_fn = <MeanBackward0 object at 0x106128fd0> x.grad = tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律