PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例

来源   https://www.cnblogs.com/picassooo/p/13757403.html

 

 

PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例

 

变量.grad_fn表明该变量是怎么来的,用于指导反向传播。例如loss = a+b,则loss.gard_fn为<AddBackward0 at 0x7f2c90393748>,表明loss是由相加得来的,这个grad_fn可指导怎么求a和b的导数。

程序示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
 
w1 = torch.tensor(2.0, requires_grad=True)
= torch.tensor([[1.2.], [3.4.]], requires_grad=True)
tmp = a[0, :]
tmp.retain_grad()   # tmp是非叶子张量,需用.retain_grad()方法保留导数,否则导数将会在反向传播完成之后被释放掉
= tmp.repeat([31])
b.retain_grad()
loss = (b * w1).mean()
loss.backward()
 
print(b.grad_fn)    # 输出: <RepeatBackward object at 0x7f2c903a10f0>
print(b.grad)       # 输出: tensor([[0.3333, 0.3333],
                    #               [0.3333, 0.3333],
                    #               [0.3333, 0.3333]])
 
print(tmp.grad_fn)    # 输出:<SliceBackward object at 0x7f2c90393f60>
print(tmp.grad)       # 输出:tensor([1., 1.])
 
 
print(a.grad)     # 输出:tensor([[1., 1.],
                  #              [0., 0.]])

手动推导:

手动推导的结果和程序的结果是一致的。

 

 

 

posted @   wodepingzi  阅读(697)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
点击右上角即可分享
微信分享提示