本文将系统性介绍GPTQ以及他爹,他爷爷,他太爷爷系列论文。GPTQ是目前极端流行的一个后训练量化算法。本文从他太爷爷开始讲起。
GPTQ使用的方法改进自OBC(Frantar, 2022),OBC来源于OBS(Hassibi, 1992),OBS来源于OBD(Lecun, 1990)
这是一种量化、剪枝用的算法
首先OBD是LeCun的论文提出的(1990):
https://proceedings.neurips.cc/paper/1989/file/6c9882bbac1c7093bd25041881277658-Paper.pdf
然后在之后被应用到了剪枝上,叫OBS(1993):
https://www.babak.caltech.edu/pubs/conferences/00298572.pdf
Optimal Brain Surgeon and General Network Pruning
首先,回忆下泰勒展开和麦克劳林级数:
对于任意一个函数p(x),可以在x=0处展开为:
p(x)=p(0)+p′(0)x+12p′′(0)x2+⋯+1n!p(n)(0)xn+o(xn)
也就是当x的取值在0附近时,p(x)可以用上述多项式近似。
同理,如果要在其他的点展开,则展开式子为:
p(x)=p(x0)+p′(x0)(x−x0)+12p′′(x0)(x−x0)2+...+1n!p(n)(x0)(x−x0)n+o(xn)(1)
对于神经网络的损失函数E(W;X)同样可以展开,按(1)展开,假定展开处的权重是W0,则
E(W)=E(W0)+((∂E∂W)∣∣∣W=W0)T⋅(W−W0)+12(W−W0)T⋅H⋅(W−W0)+o(||W−W0||3)
上式表示权重取值在W0附近时的误差的多项式近似值。下文为了表述方便,在书写二阶偏导时,不写W=W0。移项得到:
E(W)−E(W0)=(∂E∂W)T⋅(W−W0)+12(W−W0)T⋅H⋅(W−W0)+o(||W−W0||3)
将变化量记作Δ,则上式变为:
ΔE=(∂E∂W)T⋅ΔW+12ΔWT⋅H⋅ΔW+O(||ΔW||3)(2)
其中,H是hessian矩阵,也就是二阶导组成的矩阵。关于多元函数的泰勒展开,可以参考其他文章,例如:https://zhuanlan.zhihu.com/p/90496291
此处简要说明一下泰勒展开,因为上面那个知乎文章的图片是错误的。
引用自维基百科:https://zh.wikipedia.org/wiki/黑塞矩陣#性質
下文涉及到的二阶偏导矩阵用的是G(x),实际就是论文中的 Hessian矩阵

所以,多元函数的泰勒展开式为:
f(x)=f(x0)+∇f(x0)TΔx+12ΔxTG(x0)Δx+⋯
其中,
∇f(x0)=[∂f∂x1∂f∂x2⋯∂f∂xn]Tx0
为函数f(x)在x0(x1,x2,⋯,xn)点的梯度,
G(x0)=⎡⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢⎣∂2f∂x21∂2f∂x1∂x2⋯∂2f∂x1∂xn∂2f∂x2∂x1∂2f∂x22⋯∂2f∂x2∂xn⋮⋮⋱⋮∂2f∂xn∂x1∂2f∂xn∂x2⋯∂2f∂x2n⎤⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥⎦x0
将12ΔxTG(x0)Δx拆解开计算,如下图(为了简便,直接把G(x0)11简写为了G11
ΔxTG(x0)Δx=[Δx1Δx2⋯Δxn]⎡⎢
⎢
⎢
⎢⎣G11G12⋯G1nG21G22⋯G2n⋮⋮⋯⋮Gn1Gn2⋯Gnn⎤⎥
⎥
⎥
⎥⎦⎡⎢
⎢
⎢
⎢⎣Δx1Δx2⋮Δxn⎤⎥
⎥
⎥
⎥⎦=Δx1(Δx1G11+Δx2G21+⋯)+Δx2(Δx1G12+Δx2G22+⋯)+⋯(3)
其中,12ΔxTG(x0)Δx可以分解为两项的和(如上式推导):
12n∑i=1(Δxi)2Gii+12∑i!=jΔxiΔxjGij
代入(2),可以得到
ΔE=(∂E∂W)T⋅ΔW+12n∑i=1(ΔWi)2Hii+12∑i!=jΔWiΔWjHij+O(||ΔW||3)(4)
而在LeCun的OBD中,把交叉项忽略了!OBS没有忽略。
这就是论文OBS中的公式(1)。
OBS的思路是:
- 剪掉对模型误差影响最小的权重
- 更新其他权重
剪枝和更新都被表达为对权重施加一个变量ΔW
剪枝相当于在当前权重W0附近找一个权重W,该权重将部分元素剪掉,同时error的变化最小,也就是让ΔE最小。剪掉wp就是令权重矩阵中的权重wp=0,相当于Δwp+wp=0。每次剪枝一个权重,就相当于找到一个ΔW,使得(ΔW+W0)p=0,表达成数学公式就是:
eTq⋅ΔW+wq=0
其中eq是单位向量,只有在q的位置是1.
所以优化目标就变成了:
minq{minΔW(12ΔWT⋅H⋅ΔW)|eTq⋅ΔW+wq=0}
这里忽略了一次项和三次及以上的项。因为训练到局部最优的网络,一阶导数是0(或者接近0,可以忽略不计),三阶项也忽略。
这就变成了带约束的极值求解问题,用拉格朗日乘子法:
L=12ΔWT⋅H⋅ΔW+λ(eTq⋅ΔW+wq)
求解上式,得到:
ΔW=−wqH−1qqH−1⋅eq=−wqH−1qqH−1:,qLq=12w2qH−1qq
求解过程:
联立方程组求解,使得偏导数为0,求出ΔW即可。
∂L∂ΔW=0∂L∂λ=0
要求解这个方程组,我们首先求∂L∂ΔW。由于本人矩阵论比较烂,不会直接求解矩阵导数,因此我们拆开来看。
我们看对一个元素的导数,∂L∂ΔWi。要求解这个问题,先要从L=12ΔWT⋅H⋅ΔW+λ(eTq⋅ΔW+wq)中抽取出和ΔWi相关的项。
先看前半部分,12ΔWT⋅H⋅ΔW部分。由对公式(3)的推导,这部分中和ΔWi相关的项是:12ΔW2iHii+12∑j!=iΔWiWjHij,那么这部分对ΔWi的导数是:HiiWi+12∑j!=iWjHij。
那么现在就剩下后半部分λ(eTq⋅ΔW+wq),这部分对ΔWi的导数实际上恒为0。(读者可以尝试自行分类讨论,eqi==1时,Wq=−ΔWi,正好是1+(-1)=0,当,eqi==0时,对ΔWi的导数自然是0)。
那么,求解∂L∂ΔW=0,实际上等同于求解一个方程组,该方程组的变量为(W0,W1,⋯,Wn),表示成系数矩阵形式就是:
⎡⎢
⎢
⎢
⎢
⎢
⎢⎣H0012H01⋯12H0n12H10H11⋯12H0n⋮⋮⋮⋮12Hn012Hn1⋯Hnn⎤⎥
⎥
⎥
⎥
⎥
⎥⎦⎡⎢
⎢
⎢
⎢⎣ΔW0ΔW1⋮ΔWn⎤⎥
⎥
⎥
⎥⎦=0
接下来就是求解该方程组了,至于怎么求解,求解的结果对不对,本人也不知道。
OBS处理流程:
- 训练好一个收敛到最小误差的神经网络
- 计算H−1
- 计算所有的Lq=12w2qH−1qq,如果一个权重的对整体误差E的增加很小,那么该权重就需要被删除,进入step4;否则进行step5
- 使用step 3中选择到的权重索引q来进行剪枝并更新剩余权重,转到step2
- 结束剪枝
Optimal Brain Compression: A Framework for Accurate Post-Training Quantization and Pruning
上述OBS方法很好,但是,对于大型语言模型,每个循环都要计算H−1,假设矩阵的维度是d=drow∗dcol,那么hessian的维度就是d∗d,计算H−1的复杂度是Θ(d3),而整体剪枝的循环轮数是O(d)(可以理解,权重越多肯定剪枝次数越多,因为一次剪枝一个权重,次数和权重数量正相关),所以整体的算法复杂度就是O(d4)。
这样应用到大模型中,显然算法复杂度难以令人接受。所以本文提出了降低复杂度的方法。包括空间和时间复杂度。
The ExactOBS Algorithm 量化误差ΔE=||WlXl−^WlXl||22,目的是找到让ΔE最小的权重的索引,
这里还有个,关于用XXT近似二阶导的原因。其实在OBS论文里都有推导,可以看一下原论文第III部分。对于Y=WX,Loss = MSE,可以简化出来二阶导数是XX^T。这里做了一点近似,一个重要的近似就是扰动后的输出和原输出几乎一样。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
2022-07-13 tmux发送到某个会话指定命令/按键