关于RNN (循环神经网络)相邻采样为什么在每次迭代之前都需要将参数detach
关于RNN (循环神经网络)相邻采样为什么在每次迭代之前都需要将参数detach
-
这个问题出自《动手学深度学习pytorch》中RNN 第六章6.4节内容,如下图所示:
当时看到这个注释,我是一脸懵逼,(难道就不能解释清楚一点嘛,让我独自思考了那么长时间,差评!!!)我主要有以下疑惑:
-
每次小批量反向传播之后,由于torch是动态计算图,本质上该次的计算图已经销毁,与下次小批量迭代的构建的计算没有任何关联,detach不是多此一举嘛?
-
按照注释所说的,难道下次小批量构建的计算图由于初始隐藏状态引用于上次小批量迭代最后的时间步长的隐藏状态,这样计算图存在分支关联,方向传播会经过以前所有批量迭代构建的计算图,导致内存和计算资源开销大?
- 带着这两个疑惑,我开始面向百度编程(网上的博客真的是千篇一律啊,10篇当中9篇一样,哎世风日下,我也是服了,文章转来转去有意思嘛,自己收藏着看看不好嘛,非得全篇复制还转载,真的***)百度之后,我发现了以下解释(没一个有用的)
-
胡说八道型
这讲的啥?按你这么说,state是叶子节点了(估计不知道从哪抄的错误博客,害人匪浅啊),既然state都是叶子节点了,那还跟上一次批量的计算图有毛关系,反向传播个屁?叶子节点的定义:一棵树当中没有子结点(即度为0)的结点称为叶子结点。除了第一次小批量的初始隐藏状态是叶子节点外,其他批量的隐藏状态都经过隐藏层的计算,所以state已经不再是叶子节点了,而是分支节点(即grad_fn属性不为None的节点)不信,现场测试:
将源代码略微添加以上代码,验证是否为叶子节点:
看出来了吧,除了第一个小批量state 是叶子节点,其他都不是。
-
理解不到位型
哎,这张祸害不浅的知乎转载图:Z不是叶子节点,他是经过计算的节点(其他内容不粘贴了)
-
既然不是叶子节点,那detach到底有什么作用呢
首先要明确一个意识:pytorch是动态计算图,每次backward后,本次计算图自动销毁,但是计算图中的节点都还保留。
方向传播直到叶子节点为止,否者一直传播,直到找到叶子节点
我的答案是有用,但根本不是为了防止梯度开销过大(注释真的害人不浅啊),detach的真正作用是梯度节流,防止反向传播传播到隐藏状态时,因为上次小批量方向传播计算图的销毁导致继续向下传播而引起报错。啥意思呢,我以连续两次小批量迭代举例:
第一次小批量迭代,H0 是叶子节点,因为他没经过任何计算。剩余H1是非叶子节点。在第一次方向传播后,第一次的计算图已经销毁,但是节点数据仍然存在。
第二次小批量迭代,第一次批量迭代的最后时间节点的隐藏状态H2 成为第二批次小的初始隐藏状态( H0(第二次) = H2(第一次) ),这样第二次在方向传播时,当传播到H0时,发现H0 是 分支节点(grad_fn+requires_grad) ,就会继续向下传播直到找到叶子节点为止,但是可惜的是H0 之后的计算图(即第一次小批量的计算图)已经销毁,传播发生中断,因此就会导致出错。而使用detach之后,H0 自然与上次的计算图没有任何关系,H0自身变为叶子节点,这样传播到H0时自然就结束了。
好了,验证我所说的吧。
- 首先,不使用detach,会导致传播报错
将detach 操作删除
运行结果:
看到没,第二次在方向传播时出错了吧
-
使用detach,防止出错,并使H0 变为叶子节点
代码更改如下:
结果:全是true
综上:detach在这里作用,大家明白不,喜欢点个赞!!!!
至于书中为什么将detach的作用注释成那样呢,我想作者在翻译成torch的时候,忽略了MAXNET框架(原书是maxnet框架)与pytorch的区别。 MaxNet是支持静态图的,所以对于MaxNet ,detach的作用是与注释相同的,但是pytorch是动态图,所以作用在这里就不同了!!!