量化训练之可微量化参数——LSQ

(本文首发于公众号,没事来逛逛)

有读者让我讲一下 LSQ (Learned Step Size Quantization) 这篇论文,刚好我自己在实践中有用到,是一个挺实用的算法,因此这篇文章简单介绍一下。阅读这篇文章需要了解量化训练的基本过程,可以参考我之前的系列教程

LSQ 是 IBM 在 2020 年发表的一篇文章,从题目意思也可以看出,文章是把量化参数 step size (也叫 scale) 也当作参数进行训练。这种把量化参数也进行求导训练的技巧也叫作可微量化参数。在这之后,高通也发表了增强版的 LSQ+,把另一个量化参数 zero point 也进行训练,从而把 LSQ 推广到非对称量化中。

这篇文章就把 LSQ 和 LSQ+ 放在一起介绍了。由于两篇文章的公式符号不统一,为了防止符号错乱,统一使用 LSQ 论文中的符号进行表述。

普通量化训练

在量化训练中需要加入伪量化节点 (Fake Quantize),这些节点做的事情就是把输入的 float 数据量化一遍后,再反量化回 float,以此来模拟量化误差,同时在反向传播的时候,发挥 STE 的功能,把导数回传到前面的层。

Fake Quantize 的过程可以总结成以下公式 (为了方便讲解 LSQ,这里采用 LSQ 中的对称量化的方式):

\[\begin{align} \overline v&=round(clip(v/s, -Q_N,Q_P)) \tag{1} \\ \hat v&=\overline v \times s \tag{2} \end{align} \]

其中,\(v\) 是 float 的输入,\(\overline v\) 是量化后的数据 (仍然使用 float 来存储,但数值由于做了 round 操作,因此是整数),\(\hat v\) 是反量化的结果。\(-Q_N\)\(Q_P\) 分别是量化数值的最小值和最大值 (在对称量化中,\(Q_N\)\(Q_P\) 通常是相等的),\(s\) 是量化参数。

由于 round 操作会带来误差,因此 \(\hat v\)\(v\) 之间存在量化误差,这些误差反应到 loss 上会产生梯度,这样就可以反向传播进行学习。每次更新 weight 后,我们会得到新的 float 的数值范围,然后重新估计量化参数 \(s\)

\[s=\frac{|v|_{max}}{Q_P} \tag{3} \]

之后,开始新一次迭代训练。

LSQ

可以看到,上面这个过程的量化参数都是根据每一轮的权重计算出来的,而整个网络在训练的过程中只会更新权重的数值。

LSQ 想做的,就是把这里的 \(s\) 也放到网络的训练当中,而不是通过权重来计算。

也就是说,每次反向传播的时候,需要对 \(s\) 求导进行更新。

这个导数可以这样计算:把 (1)(2) 式统一一下得到:

\[\begin{align} \hat v&=round(clip(v/s, -Q_N, Q_P))\times s \tag{4} \\ &=\begin{cases}-Q_N \times s & v/s <= -Q_N \\ round(v/s)\times s & -Q_N < v/s < Q_P \tag{5} \\ Q_P \times s & v/s >= Q_P \end{cases} \end{align} \]

然后对 \(s\) 求导得到:

\[\frac{\partial \hat v}{\partial s}= \begin{cases} -Q_N & v/s <= -Q_N \\ round(v/s)+\frac{\partial round(v/s)}{\partial s}\times s & -Q_N< v/s<Q_P \tag{6} \\ Q_P & v/s >= Q_P \\ \end{cases} \]

\(round(v/s)\) 这一步的导数可以通过 STE 得到:

\[\begin{align} \frac{\partial round(v/s)}{\partial s}&=\frac{\partial (v/s)}{\partial s} \tag{7} \\ &=-\frac{v}{s^2} \notag \end{align} \]

最终得到论文中的求导公式:

\[\frac{\partial \hat v}{\partial s}= \begin{cases} -Q_N & v/s <= -Q_N \\ -\frac{v}{s}+round(v/s) & -Q_N< v/s <Q_P \tag{8} \\ Q_P & v/s >= Q_P \\ \end{cases} \]

(上面这堆公式敲得非常辛苦,给个赞不过分吧o)

作者在实验中发现,这种简单粗暴的训练方式有一个好处。

假设把量化范围固定在 [0, 3] 区间,(即 \(Q_N=0\)\(Q_P=3\))。下面 A 图表示量化前的 \(v\) 和反量化后的 \(\hat{v}\) 之间的映射关系(假设 \(s=1\)),这里面 round 采用四舍五入的原则,也就是说,在 0.5 这个地方 (图中第一道虚线),\(\hat{v}\) 是会从 0 突变到 1 的,从而带来巨大的量化误差。

因此,从 0.5 的左侧走到右侧,梯度应该是要陡然增大的。

在 B 图中,作者就对比了 QIL、PACT 和 LSQ (前面两个是另外两种可微量化参数的方法) 在这些突变处的梯度变化,结果发现,QIL 和 PACT 在突变处的梯度没有明显变化,还是按照原来的趋势走,而 LSQ 则出现了一个明显突变 (注意每条虚线右侧)。因此,LSQ 在梯度计算方面是更加合理的。

此外,作者还认为,在计算 \(s\) 梯度的时候,还需要兼顾模型权重的梯度,二者差异不能过大,因此,作者设计了一个比例系数来约束 \(s\) 的梯度大小:

\[R=\frac{\partial_s L}{s}/\frac{||\partial_w L||}{||w||} \approx 1 \tag{9} \]

同时,为了保持训练稳定,作者在 \(s\) 的梯度上还乘了一个缩放系数 \(g\),对于 weight 来说,\(g=1/\sqrt{N_W Q_P}\),对于 feature 来说,\(g=1/\sqrt{N_F Q_P}\)\(N_W\)\(N_F\) 分别表示 weight 和 feature 的大小。

而在初始化方面,作者采用 \(\frac{2|v|}{\sqrt{Q_P}}\) 的方式初始化 \(s\)

到这里,LSQ 的要点基本讲完了,其实,精华的部分就是把 \(s\) 作为量化参数进行训练,至于后面的梯度约束、初始化等,在不同网络结构、不同任务中可能需要灵活调整,没必要完全照论文来。

LSQ+

LSQ+ 的思路和 LSQ 基本一致,就是把零点 (zero point,也叫 offset) 也变成可微参数进行训练。

加入零点后,(1)(2) 式就变成了:

\[\begin{align} \overline v&=round(clip((v-\beta)/s, -Q_N,Q_P)) \tag{10} \\ \hat v&=\overline v \times s + \beta \tag{11} \end{align} \]

(高通这个零点计算方式和我之前使用的差得比较多,我自己使用的时候是遵照我之前文章的风格 \(v/s+\beta\) 来计算的,因此大家也可以灵活调整)

之后就是按照 LSQ 的方式分别计算导数 \(\frac{\partial \hat{v}}{\partial s}\)\(\frac{\partial \hat{v}}{\partial \beta}\),再做量化训练。

论文还给出了一些初始化 \(s\)\(\beta\) 的方式,但还是那句话,视具体任务、具体网络结构而定,可以自己调整 (比如我通常就按照 \(v\) 取 90% 左右的区间来估计 \(s\)\(\beta\) 的初始值),甚至你可以用 weight equalize 先预处理一遍网络的权重再来跑 LSQ+ 的算法。

实验

这两篇文章都只给出了分类任务的实验,我觉得应该增加一点别的任务来体现算法的通用性。这里就不列举实验结果了,感兴趣的同学可以看看论文。值得注意的一点是在低比特 (4bit 以下) 的情况下,精度也可以保持得比较好。

一点思考

我自己在一个 GAN 类型的网络上尝试过 LSQ+ 算法,当时被它的效果惊艳到。

这个问题的背景是这样的:最开始的时候,我用普通的量化训练 (8bit) 加上一些蒸馏的技巧来量化这个网络,结果和全精度模型差不多。后来,团队的小伙伴对这个 GAN 网络做了巨量的压缩,同时用了一些技巧大大增强了这个网络的生成能力。然后,我的量化算法在这个网络上就失效了,精度损失非常明显。期间尝试了很多种方案,但都没法拯救。

我自己在分析这个网络权重的时候,发现一个现象,随着网络被压缩得越来越小,权重的数值范围是在逐渐增大的,换句话说,这个网络本身的信息量在逐渐增大。对量化来说,这是件很可怕的事情,因为留给我量化的信息容量是固定的,就只有 8 比特。随着网络信息量增大,每次做量化训练时,round 带来的误差也会更大,这可能使得网络的梯度变得非常不稳定。甚至我会想,是不是 8 比特的信息量就不可能承载得了新网络的容量?

后来,在万念俱灰之下,尝试了 LSQ+ 算法,结果一下子把精度提高了一个档次,我感觉我又活过来了!事后分析的时候,我觉得一个很重要的原因就是:LSQ+ 在前向传播的时候,\(s\) 本身也在控制调整权重的数值分布,而且这种调整是可微的,可以用损失函数进行学习,是一种动态的调整。相比仅仅更新 weight 来调整数值分布的做法,LSQ 多了一条路径来学习。

最后,给需要做量化部署的同学提个醒,在导出量化模型进行部署时,需要根据训练好的 \(s\) 来确定权重的 minmax 大小,因为在 LSQ 的前向传播中,模型权重的数值范围是受 \(s\) 影响的,最终也是根据 \(s\) 反应到损失函数上的。

参考

欢迎关注我的公众号:大白话AI,立志用大白话讲懂AI。

posted @ 2022-04-24 23:21  大白话AI  阅读(2085)  评论(0编辑  收藏  举报