[LLM] LLM后量化(PTQ)总结及原理实现
LLM后量化(PTQ)总结及原理实现
weight only
per_channel:按照每个channel的方式,计算得到scale和zero参数,通过weight = weight * scale + zero的方式进行还原。
per_channel_group_wise:按照每个channel的方式,在per_channel的基础上产生一个scale,再增加了group_wise, 即每个channel内部再进行一次group的scale和zero,相当于更细粒度的量化.
TensorRT-LLM中的量化的gemm实现
以下是加载half 权重和反量化的代码,在TensorRT中,两个half在一个32bit中存储,形成half2数据类型,以便于混合计算。
for (int idx = 0; idx < NPerBlock; ++idx)
{
for (int i = 0; i < Details::kShuffleContinous; ++i)
{
for (int j = 0; j < Details::kShuffleStrided; ++j)
{
// Dequantize the weights and arrange the shuffled elements back to the correct order in the
// register array
half2 v = *reinterpret_cast<half2*>(weights_vec + i * Details::kShuffleBasicTile
+ j * Details::kShuffleContinous * Details::kShuffleBasicTile);
v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx]));
weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
+ j * Details::kShuffleBasicTile + 0)
* NPerBlock
+ idx]
= v.x;
weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
+ j * Details::kShuffleBasicTile + 1)
* NPerBlock
+ idx]
= v.y;
}
}
}
其中__hfma2是half2下的乘加运算,那么v = v * scale + zero;只不过这里同时计算了两个half。
v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx]));
下面这段代码是计算 v = w * v + v 的过程,也即是将v累加到每一个batch结果中的过程,当每个block计算N个时可直接累加,否则则需要拆开计算。有所不同的是,使用的是half2数据类型进行计算。
这里accumulator是half类型的全局内存。
half accumulator[Num];
for (int b = 0; b < Batch; ++b)
{
half in_v[Details::kElemsPerThread];
// Perform vector inner product and accumulate
if constexpr (NPerBlock == 1)
{
half2 v = __float2half2_rn(0.f);
for (int y = 0; y < Details::kElemsPerThread; y += 2)
{
v = __hfma2(*reinterpret_cast<half2*>(weights_f16 + y), *reinterpret_cast<half2*>(in_v + y), v);
}
accumulator[b] += __hadd(v.x, v.y);
}
else
{
for (int x = 0; x < NPerBlock / 2; ++x)
{
for (int y = 0; y < Details::kElemsPerThread; ++y)
{
*reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2)
= __hfma2(*reinterpret_cast<half2*>(weights_f16 + y * NPerBlock + x * 2),
__half2half2(in_v[y]), *reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2));
}
}
}
}
float reses[Num];
for (int i = 0; i < Num; ++i)
{
reses[i] = __half2float(accumulator[i]);
}
// Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
// corresponding address in shared memory
Details::Layout::sync<Num, WarpSize>(reses, sm);
// Each thread is responsible for the accumulation and store to global memory of one element
for (int i = tid; i < Num * Interleave; i += BlockSize)
{
int nid = i % (NPerBlock * Interleave);
float v = 0.f;
for (int j = 0; j < BlockSize / WarpSize; ++j)
{
v += sm[j][i];
}
float bias_v = 0.f;
if constexpr (Bias)
{
bias_v = __half2float(bias[n_start_id + nid]);
}
int b = i / NPerBlock / Interleave;
out[b * n + n_start_id + nid] = __float2half_rn(ActOp<float>::apply(v + bias_v));
}
smooth_quant
主要基于以下几点观察:
- 激活值比权重更加难以量化
- 异常值的存在让激活值更加难以量化
- 异常值通常出现在固定的channel中
所以smooth_quant的做法是将激活值和weight同时放缩一定倍数,这样异常激活值就可以被平滑,进而使得激活的量化不那么困难。
因为放缩本身就是乘加操作,所以可以将attn中量化操作融合到前一步中的RMSNORM操作中,节省量化开销。
按照粒度的不同,可以分为per-channel、per-token、per-tensor等几种不同的粒度。
其计算公式如下:
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
# 统计的是每一层中x,y,w的最大值
weight_scales = max_abs_value in per channel
scale = activation^alpha / (weights ^ (1-alpha))
gemm *= scale
rmsnorm_weights /= scale
计算出scale保存,在推理阶段,因为rms_norm已经除了scale,所以X不需要额外的操作,使用普通的gemm计算即可。
GPTQ
GPTQ: ACCURATE POST-TRAINING QUANTIZATION FOR GENERATIVE PR E-TRAINED TRANSFORMERS
GPTQ 将权重分组(如:128列为一组)为多个子矩阵(block)。对某个 block 内的所有参数逐个量化,每个参数量化后,需要适当调整这个 block 内其他未量化的参数,以弥补量化造成的精度损失。因此,GPTQ 量化需要准备校准数据集。
使用 Cholesky 分解中 Hessian 矩阵的逆,在给定的step中对连续列的块进行量化,并在step结束时更新剩余的权重。
取消贪心算法:OBS 采用贪心策略,先量化对目标影响最小的参数;但 GPTQ 发现直接按顺序做参数量化,对精度影响也不大。这项改进使得参数矩阵每一行的量化可以做并行的矩阵计算。
Lazy Batch-Updates:OBQ 对权重一个个进行单独更新,作者发现性能瓶颈实际在于GPU的内存带宽,而且同一个特征矩阵W不同列间的权重更新是不会互相影响的。因此作者提出了延迟批处理的方法,通过延迟一部分参数的更新,一次处理多个(如:128)列,来缓解带宽的压力,大幅提升了计算速度。
Cholesky 分解:用 Cholesky 分解求海森矩阵的逆,提前计算好所有需要的信息,在增强数值稳定性的同时,后续更新的过程中再计算,进一步减少了计算量。
columns = w.shape[1]
H = torch.zeros((self.columns, self.columns), device=self.dev)
# 计算海森矩阵的逆向
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
# 利用海森矩阵的逆依次计算误差,并更新后续参数
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
# 量化当前的weight
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
err1 = (w - q) / d
# 在同一block内, 更新后续的weight
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
# 在不同blcok间更新后续weight
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
上述代码实现是Auto_GPTQ中的实现,代码可分为两个部分,第一部分使用cholesky分解方法求得海森矩阵的逆;第二部分则根据scale和zero对w进行量化,依次得到err和loss,并据此依次更新后续的矩阵。注意这里我把group_size部分缩减了。
其计算公式为\(q = quantize(w.unsqueeze(1)); err1 = (w - q) / d; W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))\),这里d为海森逆矩阵对角线上的值,需要在同一block内和不同block间顺序更新后续权重参数。
quantizer中为普通的量化实现,quantize如下所示,会将原有的浮点数scale到对应的int8或in4区间。实现如下:
# 对称量化操作
scale = (xmax - xmin) / self.maxq
zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
# 非对称量化操作
zero = torch.round(-xmin / self.scale)
def quantize(x, scale, zero, maxq):
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
AWQ
权重对于LLM的性能并不同等重要,但要找到显著的权重通道,我们应该根据激活分布而不是权重分布,AWQ可以看作是smooth_quant的改进版。
自动搜索最优缩放,使全部权重下的量化误差最小,采用grid_search的方法对scale进行搜索,以保证最终Loss的损失值最小。
只测量每个通道的平均幅度误差来确定每个通道权重的重要性。
autoawq中的代码实现如下所示,所有过程分为4步:
- 计算每个channel weight的最大值
- 计算x的最大值
- 计算当前module的fp16输出
- 计算更新最大的scale
# [STEP 1]: Compute maximum of weight
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
weight = weight.view(-1, self.group_size)
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)
clear_memory(weight)
# [STEP 2]: Compute maximum of x
x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
# [STEP 3]: Compute output of module
with torch.no_grad():
fp16_output = module2inspect(inp, **kwargs)
if isinstance(fp16_output, tuple):
fp16_output = fp16_output[0]
# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_max, x_max, module2inspect,
layers, fp16_output, kwargs
)
具体的网格搜索算法如下所示,将按照网格对[0, 1]内的数进行遍历scale,计算出使得L2(fp16 - quant_v)的最佳scale。
- 网格遍历scale,这里n_grid=20
- 按照smooth_quant中的方式计算scale,这里多了一步平滑
- weight *= scale;将伪量化算子插入到模型中
- 计算fp16和quant_v之间的误差,保留误差最小时的scale
for ratio in range(n_grid):
# create new scales
ratio = ratio / n_grid
# NOTE: s^-1 * x is fused here, according to paper
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
scales = scales / (scales.max() * scales.min()).sqrt()
scales_view = scales.view(1, -1).to(device)
# Q(W * s)
for fc in linears2scale:
fc.weight.mul_(scales_view)
fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
# W * X
int_w_output = module2inspect(x, **kwargs)
if isinstance(int_w_output, tuple):
int_w_output = int_w_output[0]
# compute mean squared error (L2 norm)
loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
history.append(loss)
if loss < best_error:
best_error = loss
best_ratio = ratio
best_scales = scales.clone()