kZjPBD.jpg

tensor.detach()

x = torch.tensor(2.0)
x.requires_grad_(True)
y = 2 * x
z = 5 * x

w = y + z.detach()
w.backward()

 

print(x.grad)

=> 2

 

本来应该x的梯度为7,但是detach()那一路切段了梯度的传播,导致5没有向后传递

posted @   Through_The_Night  阅读(66)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示