pytorch反向传播两次,梯度相加,retain_graph=True

pytorch是动态图计算机制,也就是说,每次正向传播时,pytorch会搭建一个计算图,loss.backward()之后,这个计算图的缓存会被释放掉,下一次正向传播时,pytorch会重新搭建一个计算图,如此循环。

在默认情况下,PyTorch每一次搭建的计算图只允许一次反向传播,如果要进行两次反向传播,则需要在第一次反向传播时设置retain_graph=True,即 loss.backwad(retain_graph=True) ,这样做可以保留动态计算图,在第二次反向传播时,将自动和第一次的梯度相加。

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
 
input_ = torch.tensor([[1., 2.], [3., 4.]], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
 
l1 = input_ * w1
l2 = l1 + w2
loss1 = l2.mean()
loss1.backward(retain_graph=True)
 
print(w1.grad)  # 输出:tensor(2.5)
print(w2.grad)  # 输出:tensor(1.)
 
loss2 = l2.sum()
loss2.backward()
 
print(w1.grad)  # 输出:tensor(12.5)
print(w2.grad)  # 输出:tensor(5.)

示例中的梯度推导很简单,我在这篇博客里推了一下。从输出结果来看,程序确实是把两次的梯度加起来了。

附注:如果网络要进行两次反向传播,却没有用retain_graph=True,则运行时会报错:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

 

posted @   Picassooo  阅读(8302)  评论(0编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示