[LLM] LLM后量化(PTQ)总结及原理实现

LLM后量化(PTQ)总结及原理实现

weight only

per_channel:按照每个channel的方式,计算得到scale和zero参数,通过weight = weight * scale + zero的方式进行还原。

per_channel_group_wise:按照每个channel的方式,在per_channel的基础上产生一个scale,再增加了group_wise, 即每个channel内部再进行一次group的scale和zero,相当于更细粒度的量化.

TensorRT-LLM中的量化的gemm实现

以下是加载half 权重和反量化的代码,在TensorRT中,两个half在一个32bit中存储,形成half2数据类型,以便于混合计算。

for (int idx = 0; idx < NPerBlock; ++idx)
{
    for (int i = 0; i < Details::kShuffleContinous; ++i)
    {
        for (int j = 0; j < Details::kShuffleStrided; ++j)
        {
            // Dequantize the weights and arrange the shuffled elements back to the correct order in the
            // register array
            half2 v = *reinterpret_cast<half2*>(weights_vec + i * Details::kShuffleBasicTile
                + j * Details::kShuffleContinous * Details::kShuffleBasicTile);
            v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx]));
            weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                            + j * Details::kShuffleBasicTile + 0)
                    * NPerBlock
                + idx]
                = v.x;
            weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                            + j * Details::kShuffleBasicTile + 1)
                    * NPerBlock
                + idx]
                = v.y;
        }
    }
}

其中__hfma2是half2下的乘加运算,那么v = v * scale + zero;只不过这里同时计算了两个half。

v = __hfma2(v, __half2half2(scale[idx]), __half2half2(zero[idx]));

下面这段代码是计算 v = w * v + v 的过程,也即是将v累加到每一个batch结果中的过程,当每个block计算N个时可直接累加,否则则需要拆开计算。有所不同的是,使用的是half2数据类型进行计算。
这里accumulator是half类型的全局内存。

half accumulator[Num];
for (int b = 0; b < Batch; ++b)
{
    half in_v[Details::kElemsPerThread];
    
    // Perform vector inner product and accumulate
    if constexpr (NPerBlock == 1)
    {
        half2 v = __float2half2_rn(0.f);

        for (int y = 0; y < Details::kElemsPerThread; y += 2)
        {
            v = __hfma2(*reinterpret_cast<half2*>(weights_f16 + y), *reinterpret_cast<half2*>(in_v + y), v);
        }
        accumulator[b] += __hadd(v.x, v.y);
    }
    else
    {

        for (int x = 0; x < NPerBlock / 2; ++x)
        {
            for (int y = 0; y < Details::kElemsPerThread; ++y)
            {
                *reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2)
                    = __hfma2(*reinterpret_cast<half2*>(weights_f16 + y * NPerBlock + x * 2),
                        __half2half2(in_v[y]), *reinterpret_cast<half2*>(accumulator + b * NPerBlock + x * 2));
            }
        }
    }
}

float reses[Num];
for (int i = 0; i < Num; ++i)
{
    reses[i] = __half2float(accumulator[i]);
}

// Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
// corresponding address in shared memory
Details::Layout::sync<Num, WarpSize>(reses, sm);

// Each thread is responsible for the accumulation and store to global memory of one element
for (int i = tid; i < Num * Interleave; i += BlockSize)
{
    int nid = i % (NPerBlock * Interleave);
    float v = 0.f;
    for (int j = 0; j < BlockSize / WarpSize; ++j)
    {
        v += sm[j][i];
    }
    float bias_v = 0.f;
    if constexpr (Bias)
    {
        bias_v = __half2float(bias[n_start_id + nid]);
    }
    int b = i / NPerBlock / Interleave;
    out[b * n + n_start_id + nid] = __float2half_rn(ActOp<float>::apply(v + bias_v));
}

smooth_quant

主要基于以下几点观察:

  1. 激活值比权重更加难以量化
  2. 异常值的存在让激活值更加难以量化
  3. 异常值通常出现在固定的channel中

所以smooth_quant的做法是将激活值和weight同时放缩一定倍数,这样异常激活值就可以被平滑,进而使得激活的量化不那么困难。
因为放缩本身就是乘加操作,所以可以将attn中量化操作融合到前一步中的RMSNORM操作中,节省量化开销。
按照粒度的不同,可以分为per-channel、per-token、per-tensor等几种不同的粒度。

其计算公式如下:

act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
# 统计的是每一层中x,y,w的最大值
weight_scales = max_abs_value in per channel 
scale = activation^alpha / (weights ^ (1-alpha))
gemm *= scale
rmsnorm_weights /= scale

计算出scale保存,在推理阶段,因为rms_norm已经除了scale,所以X不需要额外的操作,使用普通的gemm计算即可。

GPTQ

GPTQ: ACCURATE POST-TRAINING QUANTIZATION FOR GENERATIVE PR E-TRAINED TRANSFORMERS

GPTQ 将权重分组(如:128列为一组)为多个子矩阵(block)。对某个 block 内的所有参数逐个量化,每个参数量化后,需要适当调整这个 block 内其他未量化的参数,以弥补量化造成的精度损失。因此,GPTQ 量化需要准备校准数据集。
使用 Cholesky 分解中 Hessian 矩阵的逆,在给定的step中对连续列的块进行量化,并在step结束时更新剩余的权重。

取消贪心算法:OBS 采用贪心策略,先量化对目标影响最小的参数;但 GPTQ 发现直接按顺序做参数量化,对精度影响也不大。这项改进使得参数矩阵每一行的量化可以做并行的矩阵计算。

Lazy Batch-Updates:OBQ 对权重一个个进行单独更新,作者发现性能瓶颈实际在于GPU的内存带宽,而且同一个特征矩阵W不同列间的权重更新是不会互相影响的。因此作者提出了延迟批处理的方法,通过延迟一部分参数的更新,一次处理多个(如:128)列,来缓解带宽的压力,大幅提升了计算速度。
Cholesky 分解:用 Cholesky 分解求海森矩阵的逆,提前计算好所有需要的信息,在增强数值稳定性的同时,后续更新的过程中再计算,进一步减少了计算量。

columns = w.shape[1]
H = torch.zeros((self.columns, self.columns), device=self.dev)

# 计算海森矩阵的逆向
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H

# 利用海森矩阵的逆依次计算误差,并更新后续参数
for i1 in range(0, self.columns, blocksize):
    i2 = min(i1 + blocksize, self.columns)
    count = i2 - i1

    W1 = W[:, i1:i2].clone()
    Q1 = torch.zeros_like(W1)
    Err1 = torch.zeros_like(W1)
    Losses1 = torch.zeros_like(W1)
    Hinv1 = Hinv[i1:i2, i1:i2]

    for i in range(count):
        w = W1[:, i]
        d = Hinv1[i, i]

        # 量化当前的weight
        q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
        Q1[:, i] = q
        Losses1[:, i] = (w - q) ** 2 / d**2

        err1 = (w - q) / d
        # 在同一block内, 更新后续的weight
        W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
        Err1[:, i] = err1

    Q[:, i1:i2] = Q1
    Losses[:, i1:i2] = Losses1 / 2
    
    # 在不同blcok间更新后续weight
    W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

上述代码实现是Auto_GPTQ中的实现,代码可分为两个部分,第一部分使用cholesky分解方法求得海森矩阵的逆;第二部分则根据scale和zero对w进行量化,依次得到err和loss,并据此依次更新后续的矩阵。注意这里我把group_size部分缩减了。
其计算公式为\(q = quantize(w.unsqueeze(1)); err1 = (w - q) / d; W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))\),这里d为海森逆矩阵对角线上的值,需要在同一block内和不同block间顺序更新后续权重参数。

quantizer中为普通的量化实现,quantize如下所示,会将原有的浮点数scale到对应的int8或in4区间。实现如下:

# 对称量化操作
scale = (xmax - xmin) / self.maxq
zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
# 非对称量化操作
zero = torch.round(-xmin / self.scale)

def quantize(x, scale, zero, maxq):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)

AWQ

权重对于LLM的性能并不同等重要,但要找到显著的权重通道,我们应该根据激活分布而不是权重分布,AWQ可以看作是smooth_quant的改进版。
自动搜索最优缩放,使全部权重下的量化误差最小,采用grid_search的方法对scale进行搜索,以保证最终Loss的损失值最小。
只测量每个通道的平均幅度误差来确定每个通道权重的重要性。

autoawq中的代码实现如下所示,所有过程分为4步:

  1. 计算每个channel weight的最大值
  2. 计算x的最大值
  3. 计算当前module的fp16输出
  4. 计算更新最大的scale
# [STEP 1]: Compute maximum of weight
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
weight = weight.view(-1, self.group_size)
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)
clear_memory(weight)

# [STEP 2]: Compute maximum of x
x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

# [STEP 3]: Compute output of module
with torch.no_grad():
    fp16_output = module2inspect(inp, **kwargs)
    if isinstance(fp16_output, tuple):
        fp16_output = fp16_output[0]

# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
    inp, w_max, x_max, module2inspect,
    layers, fp16_output, kwargs
)

具体的网格搜索算法如下所示,将按照网格对[0, 1]内的数进行遍历scale,计算出使得L2(fp16 - quant_v)的最佳scale。

  1. 网格遍历scale,这里n_grid=20
  2. 按照smooth_quant中的方式计算scale,这里多了一步平滑
  3. weight *= scale;将伪量化算子插入到模型中
  4. 计算fp16和quant_v之间的误差,保留误差最小时的scale
for ratio in range(n_grid):
    # create new scales
    ratio = ratio / n_grid

    # NOTE: s^-1 * x is fused here, according to paper
    scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
    scales = scales / (scales.max() * scales.min()).sqrt()
    scales_view = scales.view(1, -1).to(device)

    # Q(W * s)
    for fc in linears2scale:
        fc.weight.mul_(scales_view)
        fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view

    # W * X
    int_w_output = module2inspect(x, **kwargs)
    if isinstance(int_w_output, tuple):
        int_w_output = int_w_output[0]
    
    # compute mean squared error (L2 norm)
    loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow

    history.append(loss)
    if loss < best_error:
        best_error = loss
        best_ratio = ratio
        best_scales = scales.clone()
posted @ 2024-06-29 12:50  wildkid1024  阅读(533)  评论(0编辑  收藏  举报