OBS/OBD/GPTQ

本文将系统性介绍GPTQ以及他爹,他爷爷,他太爷爷系列论文。GPTQ是目前极端流行的一个后训练量化算法。本文从他太爷爷开始讲起。

GPTQ使用的方法改进自OBC(Frantar, 2022),OBC来源于OBS(Hassibi, 1992),OBS来源于OBD(Lecun, 1990)

这是一种量化、剪枝用的算法

首先OBD是LeCun的论文提出的(1990):
https://proceedings.neurips.cc/paper/1989/file/6c9882bbac1c7093bd25041881277658-Paper.pdf
然后在之后被应用到了剪枝上,叫OBS(1993):
https://www.babak.caltech.edu/pubs/conferences/00298572.pdf

Optimal Brain Surgeon and General Network Pruning

首先,回忆下泰勒展开和麦克劳林级数:
对于任意一个函数p(x),可以在x=0处展开为:

p(x)=p(0)+p(0)x+12p(0)x2++1n!p(n)(0)xn+o(xn)

也就是当x的取值在0附近时,p(x)可以用上述多项式近似。
同理,如果要在其他的点展开,则展开式子为:

(1)p(x)=p(x0)+p(x0)(xx0)+12p(x0)(xx0)2+...+1n!p(n)(x0)(xx0)n+o(xn)

对于神经网络的损失函数E(W;X)同样可以展开,按(1)展开,假定展开处的权重是W0,则

E(W)=E(W0)+((EW)|W=W0)T(WW0)+12(WW0)TH(WW0)+o(||WW0||3)

上式表示权重取值在W0附近时的误差的多项式近似值。下文为了表述方便,在书写二阶偏导时,不写W=W0。移项得到:

E(W)E(W0)=(EW)T(WW0)+12(WW0)TH(WW0)+o(||WW0||3)

将变化量记作Δ,则上式变为:

(2)ΔE=(EW)TΔW+12ΔWTHΔW+O(||ΔW||3)

其中,H是hessian矩阵,也就是二阶导组成的矩阵。关于多元函数的泰勒展开,可以参考其他文章,例如:https://zhuanlan.zhihu.com/p/90496291

此处简要说明一下泰勒展开,因为上面那个知乎文章的图片是错误的。
引用自维基百科:https://zh.wikipedia.org/wiki/黑塞矩陣#性質
下文涉及到的二阶偏导矩阵用的是G(x),实际就是论文中的 Hessian矩阵
image

所以,多元函数的泰勒展开式为:

f(x)=f(x0)+f(x0)TΔx+12ΔxTG(x0)Δx+

其中,

f(x0)=[fx1fx2fxn]x0T

为函数f(x)x0(x1,x2,,xn)点的梯度,

G(x0)=[2fx122fx1x22fx1xn2fx2x12fx222fx2xn2fxnx12fxnx22fxn2]x0

12ΔxTG(x0)Δx拆解开计算,如下图(为了简便,直接把G(x0)11简写为了G11

(3)ΔxTG(x0)Δx=[Δx1Δx2Δxn][G11G12G1nG21G22G2nGn1Gn2Gnn][Δx1Δx2Δxn]=Δx1(Δx1G11+Δx2G21+)+Δx2(Δx1G12+Δx2G22+)+

其中,12ΔxTG(x0)Δx可以分解为两项的和(如上式推导):

12i=1n(Δxi)2Gii+12i!=jΔxiΔxjGij

代入(2),可以得到

(4)ΔE=(EW)TΔW+12i=1n(ΔWi)2Hii+12i!=jΔWiΔWjHij+O(||ΔW||3)

而在LeCun的OBD中,把交叉项忽略了!OBS没有忽略。

这就是论文OBS中的公式(1)。

OBS的思路是:

  1. 剪掉对模型误差影响最小的权重
  2. 更新其他权重

剪枝和更新都被表达为对权重施加一个变量ΔW

剪枝相当于在当前权重W0附近找一个权重W,该权重将部分元素剪掉,同时error的变化最小,也就是让ΔE最小。剪掉wp就是令权重矩阵中的权重wp=0,相当于Δwp+wp=0。每次剪枝一个权重,就相当于找到一个ΔW,使得(ΔW+W0)p=0,表达成数学公式就是:

eqTΔW+wq=0

其中eq是单位向量,只有在q的位置是1.

所以优化目标就变成了:

minq{minΔW(12ΔWTHΔW)|eqTΔW+wq=0}

这里忽略了一次项和三次及以上的项。因为训练到局部最优的网络,一阶导数是0(或者接近0,可以忽略不计),三阶项也忽略。

这就变成了带约束的极值求解问题,用拉格朗日乘子法:

L=12ΔWTHΔW+λ(eqTΔW+wq)

求解上式,得到:

ΔW=wqHqq1H1eq=wqHqq1H:,q1Lq=12wq2Hqq1

求解过程:
联立方程组求解,使得偏导数为0,求出ΔW即可。

LΔW=0Lλ=0

要求解这个方程组,我们首先求LΔW。由于本人矩阵论比较烂,不会直接求解矩阵导数,因此我们拆开来看。
我们看对一个元素的导数,LΔWi。要求解这个问题,先要从L=12ΔWTHΔW+λ(eqTΔW+wq)中抽取出和ΔWi相关的项。
先看前半部分,12ΔWTHΔW部分。由对公式(3)的推导,这部分中和ΔWi相关的项是:12ΔWi2Hii+12j!=iΔWiWjHij,那么这部分对ΔWi的导数是:HiiWi+12j!=iWjHij
那么现在就剩下后半部分λ(eqTΔW+wq),这部分对ΔWi的导数实际上恒为0。(读者可以尝试自行分类讨论,eqi==1时,Wq=ΔWi,正好是1+(-1)=0,当,eqi==0时,对ΔWi的导数自然是0)。
那么,求解LΔW=0,实际上等同于求解一个方程组,该方程组的变量为(W0,W1,,Wn),表示成系数矩阵形式就是:

[H0012H0112H0n12H10H1112H0n12Hn012Hn1Hnn][ΔW0ΔW1ΔWn]=0

接下来就是求解该方程组了,至于怎么求解,求解的结果对不对,本人也不知道。

OBS处理流程:

  1. 训练好一个收敛到最小误差的神经网络
  2. 计算H1
  3. 计算所有的Lq=12wq2Hqq1,如果一个权重的对整体误差E的增加很小,那么该权重就需要被删除,进入step4;否则进行step5
  4. 使用step 3中选择到的权重索引q来进行剪枝并更新剩余权重,转到step2
  5. 结束剪枝

Optimal Brain Compression: A Framework for Accurate Post-Training Quantization and Pruning

上述OBS方法很好,但是,对于大型语言模型,每个循环都要计算H1,假设矩阵的维度是d=drowdcol,那么hessian的维度就是dd,计算H1的复杂度是Θ(d3),而整体剪枝的循环轮数是O(d)(可以理解,权重越多肯定剪枝次数越多,因为一次剪枝一个权重,次数和权重数量正相关),所以整体的算法复杂度就是O(d4)

这样应用到大模型中,显然算法复杂度难以令人接受。所以本文提出了降低复杂度的方法。包括空间和时间复杂度。

The ExactOBS Algorithm 量化误差ΔE=||WlXlWl^Xl||22,目的是找到让ΔE最小的权重的索引,

这里还有个,关于用XXT近似二阶导的原因。其实在OBS论文里都有推导,可以看一下原论文第III部分。对于Y=WX,Loss = MSE,可以简化出来二阶导数是XX^T。这里做了一点近似,一个重要的近似就是扰动后的输出和原输出几乎一样。

posted @   王冰冰  阅读(2071)  评论(7编辑  收藏  举报
相关博文:
历史上的今天:
2022-07-13 tmux发送到某个会话指定命令/按键
点击右上角即可分享
微信分享提示
💬
评论
📌
收藏
💗
关注
👍
推荐
🚀
回顶
收起