【杂学】大模型推理加速 —— KV-cache 技术

如果不熟悉 Transformer 的同学可以点击这里了解

自从《Attention Is All You Need》问世以来,Transformer 已经成为了 LLM 中最基础的架构,被广泛使用。KV-cache 是大模型推理加速的关键技术之一,已经成为了 Transformer 标配的功能,不过其只能用于 Decoder 结构:由于 Decoder 中有 Mask 机制,推理的时候前面的词不需要与后面的词计算 attention score,因此 KV 矩阵可以被缓存起来用于多次计算。

例如,我们要生成 "I love Tianjin University" 这句话。首先是只有开头标记 <s>,计算过程如下图所示(为了便于理解,我们将 softmax 和 scale 去掉):

image

最终第一步的注意力 \(\text{Att}_{step1}\) 计算公式为:

\[{\color{red}\text{Att}_1}(Q,K,V)=({\color{red}{Q_1}}K_1^T)\overrightarrow{V_1} \]

此时序列中的词为 "<s> I",由于有 Mask 机制,第二步计算如下图所示:
image

第二步的注意力为:

\[\begin{aligned} \text{Att}_{step2}&=\begin{bmatrix}{\color{red}{Q_1}}K_{1}^{T}&0\\{\color{green}{Q_2}}K_{1}^{T}&{\color{green}{Q_2}}K_{2}^{T}\end{bmatrix}\begin{bmatrix}\overrightarrow{V_{1}}\\\overrightarrow{V_{2}}\end{bmatrix} =\begin{bmatrix}{\color{red}{Q_1}}K_1^T\times\overrightarrow{V1}\\{\color{green}{Q_2}}K_1^T\times\overrightarrow{V1}+{\color{green}{Q_2}}K_2^T\times\overrightarrow{V2}\end{bmatrix} \end{aligned}\]

\(\text{Att}_1\) 是第一行,\(\text{Att}_2\) 是第二行,则有:

\[\begin{aligned}&{\color{red}\text{Att}_1}(Q,K,V)={\color{red}{Q_1}}K_1^T\overrightarrow{V_1}\\&{\color{green}\text{Att}_2}(Q,K,V)={\color{green}{Q_2}}K_1^T\overrightarrow{V_1}+{\color{green}{Q_2}}K_2^T\overrightarrow{V_2}\end{aligned} \]

此时我们可以大胆猜想:

  • \(\text{Att}_k\) 只与 \(Q_k\) 有关
  • 已经计算出的 \(\text{Att}\) 永远都不会改变

带着这个猜想,继续生成下面的词,容易计算第三步可以得到:

image

\[\begin{aligned}&{\color{red}\text{Att}_1}(Q,K,V)={\color{red}{Q_1}}K_1^T\overrightarrow{V_1}\\&{\color{green}\text{Att}_2}(Q,K,V)={\color{green}{Q_2}}K_1^T\overrightarrow{V_1}+{\color{green}{Q_2}}K_2^T\overrightarrow{V_2}\\&{\color{blue}\text{Att}_3}(Q,K,V)={\color{blue}{Q_3}}K_1^T\overrightarrow{V_1}+{\color{blue}{Q_3}}K_2^T\overrightarrow{V_2}+{\color{blue}{Q_3}}K_3^T\overrightarrow{V_3}\end{aligned} \]

同样的,\(\text{Att}_k\) 只与 \(Q_k\) 有关。第四步也相同。

看上面的图和公式,我们可以归纳出性质:

  1. 朴素的 Attention 计算存在大量冗余
  2. \(\text{Att}_k\) 只与 \(Q_k\) 有关,即预测词 \(x_k\) 仅依赖于 \(x_{k-1}\)
  3. \(K\)\(V\) 全程参与计算,可以缓存起来
  4. 虽然叫做 KV-cache,但其实真正优化掉的是冗余的 \(Q\)\(\text{Att}\)

当然,这有点类似于动态规划思想,也存在利用空间换取时间的问题,因此当序列很长时,KV-cache 有可能会出现内存爆炸的情况。

下面附上 gpt 的 KV-cache 代码,非常简单,仅仅是做了 concat 操作。不过值得注意的是,attention 的计算并没有使用 cache。

if layer_past is not None:
        past_key, past_value = layer_past
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)
    
    if use_cache is True:
        present = (key, value)
    else:
        present = None
    
    if self.reorder_and_upcast_attn:
        attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
    else:
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
posted @ 2024-11-13 20:47  KeanShi  阅读(66)  评论(0编辑  收藏  举报