20 Pytorch 求导机制
Pytorch 求导机制
参考链接:
https://zhuanlan.zhihu.com/p/38475183(报错解决,情况1)
https://blog.csdn.net/m0_38129460/article/details/90405086(Inplace operation)
https://zhuanlan.zhihu.com/p/113112455(链式法则)
https://zhuanlan.zhihu.com/p/33378444(计算图)
https://www.jianshu.com/p/ff74ccae25f3(requires_grad_(), detach(), torch.no_grad()的区别)
https://zhuanlan.zhihu.com/p/69294347 & https://zhuanlan.zhihu.com/p/67184419(Autograd、detach()好文!)
昨天代码报了个错:RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
然后就想着把 torch 的自动求导机制搞懂,上面的几篇博客基本涵盖了,我简单总结下:
1、计算图:
- 计算图通常包含两种元素,一个是 tensor,另一个是 Function。
- Function 指的是在计算图中某个节点(node)所进行的运算,比如加减乘除卷积等等之类的,Function 内部有
forward()
和backward()
两个方法,分别应用于正向、反向传播。 - torch构建的计算图是动态图,即为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放。
- 叶子张量(leaf tensor):用户创建的,其导数要保留,有用。
2、链式法则:
3、Inplace operation:
Inplace操作会改变对应内存的值,如 += ,sum_()等,一般带下划线的都是Inplace操作
强行替换会导致在求导过程中报错,详情见https://zhuanlan.zhihu.com/p/38475183
4、Tensor.requires_grad_()、Tensor.requires_grad()、Tensor.detach()、Tensor.data():
requires_grad_()
的主要用途是告诉自动求导开始记录对Tensor
的操作,即requires_grad=True
requires_grad()
是查看是否要计算梯度
detach()
函数会返回一个新的Tensor
对象b
,并且新Tensor
是与当前的计算图分离的,其requires_grad
属性为False
torch.no_grad()
是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量
5、Debug 的时候可以用 with torch.autograd.set_detect_anomaly(True)
详情见:https://github.com/pytorch/pytorch/issues/15803
另外,计算图的可视化可以使用Torchviz