Deep Learning中的Large Batch Training相关理论与实践
背景
[作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor]
欢迎大家关注我的公众号,“互联网西门二少”,我将继续输出我的技术干货~
在分布式训练时,提高计算通信占比是提高计算加速比的有效手段,当网络通信优化到一定程度时,只有通过增加每个worker上的batch size来提升计算量,进而提高计算通信占比。然而一直以来Deep Learning模型在训练时对Batch Size的选择都是异常敏感的,通常的经验是Large Batch Size会使收敛性变差,而相对小一点的Batch Size才能收敛的更好。当前学术界和工业界已经有一些论文来论证Large Batch Size对收敛性的影响,甚至提出了一些如何使用Large Batch去提高收敛性的方法,本文将对这些论文的重点和脉络做一个梳理。
论文脉络梳理
Large Batch Training是目前学术界和工业界研究的热点,其理论发展非常迅速。但由于非凸优化和Deep Learning的理论研究本身还处于并将长期处于初级阶段,所以即使存在各种各样的理论解释和证明,Large Batch Training相关的理论也尚未得到彻底的解释。为了能够让读者能够更容易理解Large Batch Training当前的学术发展,也为了让论文的阅读更有脉络,我们把学术界中的相关论文按照观点的提出顺序作为梳理如下。下面列出的每篇论文后面都有其要点,便于读者阅读时有个大概的感觉。因为本篇主要梳理Large Batch Training的理论部分,所以会对重点的论文进行分析解释。
- 《ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA》:这篇论文解释了Large Batch Training使收敛性变差的原因:使用Large Batch更容易落入Sharp Minima,而Sharp Minima属于过拟合,所以其泛化性比较差。
- 《Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour》:这是FaceBook提出的一篇极具争议性的论文,从实践上来说它的的复现难度也是比较大的。该论文从实践的角度出发,在ResNet上提出了一种针对Large batch training的训练方法,即learning rate scaling rule。当batch size相对于baseline增加N倍时,learning rate也要相应的增加N倍,但也指出batch size的提升有一个upper bound,超过这个值,泛化性依然会变得很差。这篇论文对learning rate scaling rule有一些公式推导,但并不本质,更多的是做了较强的假设。总体来说,这是一篇实验做得比较solid,但理论基础并不丰满的实践论文。
- 《A BAYESIAN PERSPECTIVE ON GENERALIZATION AND STOCHASTIC GRADIENT DESCENT》:这是Google发在ICLR 2018上的一篇理论和实验都比较完善的论文。因为在ResNet上已经有了Learning Rate Scaling Rule的成功经验,因此该论文从贝叶斯的角度解释了泛化性和SGD。论文的核心观点是指出了Batch Training相对于Full Batch Training来说引入了Noise,而Noise具有波动的效果,这在论文里被称为Flucturate,它可以在更新时在一定程度上偏离Sharp Minima,从而进入Broad Minima,进而有了较好的泛化性,所以Noise起了较大的作用。进一步的,论文中将SGD的更新公式进行进行分析,等价为一个微分方程的定积分结果,通过将SGD更新公式与微分方程进行等价,导出了Flucturate的表达式,确定了影响其值的变动因素,即和Learning Rate与Batch size有关。若把Flucturate看做常量,那么Learning Rate与Batch Size可以近似看做是线性关系,这与论文2中的Learning Rate Scaling Rule一致。总体来说,这篇论文数学理论相对丰满的解释了Learning Rate Scaling Rule。
- 《Don't Decay the Learning Rate, Increase the Batch Size》:这是Google发在ICLR 2018上的第二篇论文,这篇论文的实验和结论非常简单,但是理论基础依然来自于论文3,所以阅读此篇论文之前一定要精度论文3。该论文从推导出的Mini Batch SGD的Flucturate公式出发,提出了一种使用Large Batch Training的加速方法。因为在一个完整的模型训练过程中,通常会随着轮数的增加而适当对Learning Rate做Decay。通过论文3中给出的公式,即Flucturate固定时,Learning Rate与Batch Size成正比关系,引发了思考:究竟是Learning Rate本身需要Decay才能使训练过程继续,还是Learning Rate的Decay间接影响了Noise的Flucturate才能使训练过程继续?通过实验验证,真正影响训练过程的本质是Noise的Flucturate。因此我们考虑到Learning Rate与Batch Size的正比例关系,我们可以固定Learning Rate不变,而将Batch Size增加N倍来缩小Noise的Flucturate。定时增加Batch Size不但可以维持原有方式的Flucturate,还可以加速训练过程,减少Update的更新频次,增加计算通信占比,提高加速比。总体来说,该论文基于论文3为理论基础,提出了一种逐渐增加Batch Size提高计算加速比和收敛加速比的方法。
要点梳理
可以按顺序梳理成以下几个方面
理论基础
- 从贝叶斯理论角度出发,论证Broad Minima相对于Sharp Minima具有更好的泛化性
- 用贝叶斯理论解释泛化性是有效的
- 贝叶斯理论与SGD
- 随机偏微分方程的与Scaling Rule的推导
优化方法
- 使用Large Batch Training提高训练速度
理论基础
理论基础来自于论文《A BAYESIAN PERSPECTIVE ON GENERALIZATION AND STOCHASTIC GRADIENT DESCENT》,这里只对重点内容进行记录。
从贝叶斯理论角度出发,论证broad minima相对于sharp minima具有更好的泛化性
内容
这部分公式较多,但确实是贝叶斯的理论基础,所以尽量以简单的形式展现出来。首先假设某模型M只有一个参数w,训练样本为x,Label为y,那么可以跟据贝叶斯公式直接写出下面的等式
其中等号右面分母上的第一项可以看做似然函数
一般情况下,我们对模型参数的分布会做高斯假设
所以有
可以看出这个公式就是模型训练中Loss Function的主要部分,前面一项H(w;M)是Cost,而后面一项是正则项。我们要最小化Loss Function,本质上是最大化C(w;M)这一项。假设我们训练了两组模型参数,如何判断哪一个模型的泛化性更好?这里使用如下公式来判断。
等式右面的第二项是对模型的偏好因子,在这里应该均设置为1,消除偏置的影响。右边第一项我们叫做Bayesian Evidence Ratio,它描述了训练样本改变了我们对模型先验偏好的程度。为了计算这个比值,我们需要计算分子和分母。
使用泰勒展开式对C(w;M)在最优值w_0附近进行近似展开,得到如下式子。
至此,我们可以对上述公式的结果进行分析。上述公式中最后一项其实就是Occam Factor。通过分析我们也知道二阶导数正负衡量的是函数的凹凸性,而二阶导数的大小衡量和曲率相关。当C''(w_0)越大时,该位置附近就越弯曲,越接近sharp minima,进而导致P(y|x;M)的概率越低,这符合Occam Razor的原则,越简单的模型泛化性越好,这是因为简单的模型是在Broad Minima上。也可以提高正则系数对C''(w_0)进行惩罚,从而控制Occam factor,提高泛化性。当扩展到多个参数后,该公式如下所示。
分析方法相同,不再赘述。
小结
这一部分作者从贝叶斯理论出发,从公式上推导出了Occam Razor的结论,并且论证了落入Sharp Minima的模型泛化性较差的原因,同时也得出了正则项对Sharp Minima具有惩罚作用。
用贝叶斯理论解释泛化性是有效的
内容
这里作者借鉴了论文《Understanding deep learning requires rethinking generalization》中的实验来从贝叶斯理论解释泛化性,与ICLR 2017的这篇Best Paper使用的Deep Learning Model不同,作者使用了最简单的线性模型进行实验,原因是线性模型在计算Bayesian Evidence的时候比Deep Learning简单很多。具体的实验配置可以参考论文,这里直接给出图表。
注:Bayesian Evidence实际上是Log Bayesian Evidence,对上面的结果取了对数。
这个实验主要是为了证明Bayesian Evidence的曲线和Test Cross Entropy的变化趋势是一致的,并且也复现了《Understanding deep learning requires rethinking generalization》中呢Deep Learning Model的结果。
小结
这一节中的实验证明,使用贝叶斯理论解释泛化性是有效的,并且得出了预期一致的结果。
贝叶斯理论与SGD
内容
在得出Bayesian Evidence和泛化性是强相关关系的结论之后,作者再次对SGD产生了思考。因为无论是Large Batch还是Small Batch,他们都是Full Batch的近似结果,所以都会引入Noise。作者认为造成不同Batch Size产生不同泛化性的根本原因是Noise的Flucturate程度。一定程度的Noise可以逃离Sharp Minima,带领模型进入Bayesian Evidence较大的区域,即Broad Minima区域;而Batch Size越大,Noise的Flucturate就越小,就很容易陷入Sharp Minima。(这部分的公式推导在这里先不给出,因为这不是这篇文章的重点,有兴趣的同学可以关注这篇论文的附录A)这说明SGD的更新规则本身就带有了一些正则化效果,这个正则化的效果很大程度上来自于SGD本身引入的Noise。这与ICLR 2017 Best Paper《Understanding deep learning requires rethinking generalization》观察到的现象和得出的结论一致,该篇文章中主要思考的一个问题是,SGD在训练完全部样本之后,为什么不是记住所有的样本,而是还学到了一些泛化性?
回到这篇论文,作者认定一定存在一个最佳Batch Size,这个Batch Size既没有使模型进入Sharp Minima区域,又有一定的复杂性,使之让当前的模型效果最好。于是做了不同的实验,得到以下结果。
这些实验其实就是验证不同Batch Size训练出的模型在test集上的表现,并说明存在一个最佳的Batch Size,使用它训练出的模型,其泛化性优于其他Batch Size训练出的模型。
小结
这一部分从对贝叶斯与泛化性的思考入手,进而尝试解释SGD的特点,从而试图验证不同Batch Size对泛化性的影响。Batch Size的选取可以看成是Depth(Sharp)和Breadth(Broad)的Trade off,所以存在一个最佳的Batch Size,在其他超参数固定时使模型达到最好的泛化效果。
随机偏微分方程的与scaling rule的推导
内容
因为Batch Size的选取,从贝叶斯角度去理解,实际上就是Depth和Breadth的Trade off。所以可以更进一步的对SGD引入的Noise进行分析,进一步去探究这个Noise带来的Flucturate与哪些因素相关,这就需要和随机偏微分方程建立联系了。
首先,将SGD的update公式进行改写。
其中N代表训练集的样本数,ε代表学习率。假设我们用<>代表期望的计算,那么我们有
根据中心极限定理,我们可以得出以下结论
所以标准的Stochastic Gradient Descent可以看成是标准梯度加上一个Noise,这个Noise就是α中的内容。下面进一步研究Noise的性质。
其中,F(w)为梯度的协方差项,δ_ij代表了Indicator,即当i=j时,δ_ij=1,否则等于0。这是因为样本和样本之间是相互独立的关系,所以协方差应该等于0。如果看不懂这个公式可以按照下面的原型推理,一目了然。
根据协方差矩阵的可列可拆的性质,我们求得如下期望。
至此,Noise的统计特性已经全部计算出来,下面需要和随机偏微分方程进行等价。首先,SGD的Update规则是一个离散的过程,不是连续的过程。如果我们把SGD的每一步想象成为一个连续的可微分的过程,每次Update一个偏微分算子,那么可以将上述学习率为ε的Update公式看成是某个微分方程的定积分结果,下面先介绍这个偏微分方程(这个偏微分方程的产生来自于《Handbook of Stochastic Methods》)。
这里t是连续的变量,η(t)代表了t时刻的Noise,具有如下性质。
因为我们知道Noise的期望必定等于0,而方差会有个波动的Scale,且波动的大小是以F(w)有关,所以这个Scale我们用g来表示,即Flucturate。而SGD的Update规则可以改写如下所示。
为了探求g的变化因素,我们需要将偏微分方程的最后一项的方差和SGD的α方差对应起来,得到
上面最后的积分公式推导可能会有些迷惑,大概是会迷惑在积分的方差是如何化简到二重积分这一过程,其实积分符号只是个对连续变量的求和过程,所以依然可以使用协方差的可列可拆的性质,如果还是不习惯,将积分符合和dt换成求和符号再去使用协方差公式即可轻松得到结论。
所以,我们得到了相当重要的结论,这是在一定程度上能够解释Learning Rate Scaling Rule的结论。
所以,我们得到了结论,SGD引入了一些Noise,这个Noise具有一定的Flucturate,它的大小是和Batch Size成反比,与Learning Rate成正比。
小结
这一节使用偏微分方程和SGD的更新规则,经过一系列的数学推导,得到了SGD引入的Noise对更新过程的Flucturation大小与Batch size和Learning rate的关系。这是这篇论文十分重要的结论,也是Learning Rate Scaling Rule的理论基石。
理论总结
至此,理论基础部分梳理完毕,虽然公式较多较为复杂,但是结论却非常简单。作者从贝叶斯理论的角度出发,推导出了Occam Razor的形式表达,并从公式上论证了Sharp Minima相对于Broad Minima泛化性差的原因。而后又验证了Bayesian Evidence和模型泛化性一致的结论,进而从贝叶斯理论的角度对SGD的更新过程进行了猜测:SGD会引入Noise,而正是Noise的Flucturate帮助模型在更新过程中逃离Sharp Minima,进入更高的Bayesian Evidence区域,即Broad Minima,所以指出Batch Size的选择实际上是Noise Flucturate的调整,本质上是Sharp Minima和Broad Minima的Trade off。最后作者通过将SGD更新公式进行改写,并联合偏微分方程,得出了Noise的Fluctruate的形式表达,它Batch Size成反比,和Learning Rate成正比。
之前FAIR发表的论文《Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour》中提出了Learning Rate Scaling Rule在ResNet上具有很好的效果,该论文在实验上做的比较充分,但是在理论上并没有特别Solid,而Google的这篇论文可以作为它的理论基石之一。
优化方法
优化方法来自论文《Don't Decay the Learning Rate, Increase the Batch Size》,这篇论文在理解完前一篇论文之后会显得非常简单,完全是一篇实验性论文,实验做得较为充分,这里只会对重要内容做个简单的梳理。
理论基础公式
对于SGD来说,Flucturation形式表达为
对于Momentum-SGD来说,形式表达为(公式推导来自于langvein动力学)
Large batch training的优化原理
无论是SGD还是Momentum-SGD,我们都可以发现g与Batch Size成反比,与Learning Rate成正比,而在一般的Deep Learning Model训练过程中,会在固定轮数对Learning Rate做Decay,这个过程让作者引发了思考,究竟在训练过程中,泛化性的提升是由于Learning Rate做Decay导致的,还是g发生变化导致的?如果是后者,那么定时增加Batch Size也应该会达到同样的效果,因此作者做了几组实验。
作者做了三组实验,一组是标准的对Learning Rate做Decay,一组是固定Rearning Rate不变,在原来发生Learning Rate Decay的轮数将Batch Size扩大N倍(N是Learning Rate Decay的Factor,即与Learning Rate的Decay为相同力度)。另一组是二者的结合Hybrid,即先Learning Rate Decay,后变化Batch Size。实验证明三者的泛化性曲线相同,所以证明了Learning Rate Decay实际上是对g做了Scale down。然而增加Batch Size不但可以达到同样的效果,还能提高计算通信占比,并且在整体训练过程中减少Update的次数,这是Increase Batch Size Training的优化点。
关于Momentum-SGD
在Momentum-SGD的flucturation形式表达中,我们还看到了momentum的作用,即增加m的值可以增加g的值。但是实验证明,增加m同时扩大batch size得到的泛化性相对于改变learning rate和batch size要差一些。这是因为提高momentum会使Momentum-SGD中的accumulator需要更多的轮数才能到达稳定的状态,而在到达稳定状态之前,update的scale是会被supressed的,作者在论文附录中论证了这一观点,这里不再详细赘述。后续的实验也证明了这一点。
更大Batch Size和消除Warm Up
在论文《Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour》中,作者实验的最大Batch Size为8192。然而在这篇论文中,作者使用更大的初始Batch Size(最大尝试到65536)对ImageNet进行训练,并且在固定的轮数对Noise做Decay(增加Batch Size)。作者消去了Warm Up的过程,但是引入了Mometum的超参调优,当使用更大Batch Size时,不仅调整初始Learning Rate,还增加m值来进一步放大Noise,帮助训练过程逃离Sharp Minima。实验效果如下。
小结
此篇论文更像是《A BAYESIAN PERSPECTIVE ON GENERALIZATION AND STOCHASTIC GRADIENT DESCENT》工作的延续,以该篇论证的理论基础出发,得出了一种提高训练计算加速比和收敛加速比的方法。结论和实验比较简单,但背后的数学推导较为复杂。
总结
工业界的分布式算力提升对Large Batch Training提出了需求,因为增加Batch Size显然是提高计算通信占比的最佳方式,所以Large Batch Training固有的收敛性问题就成为了学术界研究的重点方向。本文通过梳理近些年来学术界对Large Batch Training的论文研究,从理论角度阐述了Large Batch Training造成收敛性较差的原因——容易陷入Broad Minima。而Google发表的论文从贝叶斯角度给出了另外的解释——不同Batch Size训练引入的Noise不同造成Fluctuate也不同,最终导致收敛性的不同。为了验证这一观点,Google又从实践角度给出了验证——通过固定Learning Rate,逐步增大Batch Size来稳定Fluctuate,达到使用大Batch Size加速训练的目的。截止到目前,这些理论方面的论证和解释依然处于蓬勃发展之中,未来还会有更深入研究在学术界中出现。
博客园——DeepLearningStack,未经授权严禁转载