从零开始学习MXnet(五)MXnet的黑科技之显存节省大法
写完发现名字有点拗口。。- -#
大家在做deep learning的时候,应该都遇到过显存不够用,然后不得不去痛苦的减去batchszie,或者砍自己的网络结构呢? 最后跑出来的效果不尽如人意,总觉得自己被全世界针对了。。遇到这种情况怎么办? 请使用MXnet的天奇大法带你省显存! 鲁迅曾经说过:你不去试试,怎么会知道自己的idea真的是这么糟糕呢?
首先是传送门附上 mxnet-memonger,相应的paper也是值得一看的 Training Deep Nets with Sublinear Memory Cost。
实际上repo和paer里面都说的很清楚了,这里简单提一下吧。
一、Why?
节省显存的原理是什么呢?我们知道,我们在训练一个网络的时候,显存是用来保存中间的结果的,为什么需要保存中间的结果呢,因为在BP算梯度的时候,我们是需要当前层的值和上一层回传的梯度一起才能计算得到的,所以这看来显存是无法节省的?当然不会,简单的举个例子:一个3层的神经网络,我们可以不保存第二层的结果,在BP到第二层需要它的结果的时候,可以通过第一层的结果来计算出来,这样就节省了不少内存。 提醒一下,这只是我个人的理解,事实上这篇paper一直没有去好好的读一下,有时间在再个笔记。不过大体的意思差不多就是这样。
二、How?
怎么做呢?分享一下我的trick吧,我一般会在symbol的相加的地方如data = data+ data0这种后面加上一行 data._set_attr(force_mirroring='True'),为什么这么做大家可以去看看repo的readme,symbol的地方处理完以后,只有如下就可以了,searchplan会返回一个可以节省显存的的symbol给你,其它地方完全一样。
1 import mxnet as mx 2 import memonger 3 4 # configure your network 5 net = my_symbol() 6 7 # call memory optimizer to search possible memory plan. 8 net_planned = memonger.search_plan(net) 9 10 # use as normal 11 model = mx.FeedForward(net_planned, ...) 12 model.fit(...)
PS:使用的时候要注意,千万不要在又随机性的层例如dropout后面加上mirror,因为这个结果,再算一次就和上一次不同了,会让你的symbol的loss变得很奇怪。。
三、总结
天奇大法吼啊!