LLM 加速技巧:Muti Query Attention
前言 MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。在大语言模型时代被广泛使用,很多LLM都采用了MQA,如Falcon、PaLM、StarCoder等。
本文转载自Deephub Imba
作者:Florian June
仅用于学术分享,若侵权请联系删除
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
【CV技术指南】CV全栈指导班、基础入门班、论文指导班 全面上线!!
在介绍MQA 之前,我们先回顾一下传统的多头注意力
Multi-Head Attention(MHA)
多头注意力是transformer 模型的默认注意力机制,如下图所示:
在文本生成方面,基于transformer 的自回归语言模型存在一个问题。在训练过程中可以获得真实的目标序列,并且可以有效地实现并行化。
但是在推理过程中,每个位置的查询都要处理在该位置或之前生成的所有键值对。也就是说自注意力层在特定位置的输出影响下一个令牌的生成,所以无法并行化,这使得推理变得非常的慢。
下图是基于transformer 解码器的自回归语言模型中自注意层的解码过程:
def MHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
q = tf.einsum("bd, hdk−>bhk", x, P_q)
new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis = 2)], axis = 2)
new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis = 2)], axis = 2)
logits = tf.einsum("bhk, bhmk−>bhm", q, new_K)
weights = tf.softmax(logits)
O = tf.einsum("bhm, bhmv−>bhv", weights, new_V)
Y = tf.einsum("bhv, hdv−>bd", O, P_o)
return Y, new_K, new_V
其中:
X:当前的输入张量,m为当前步,m+1为阶跃,形状为[b, d]
P_q, P_k:查询和键投影张量,形状为[h, d, k]
P_v:值投影张量,形状为[h, d, v]
P_o:学习到的线性投影,形状为[h, d, v]
Prev_K:上一步的关键张量,形状为[b, h, m, k]
Prev_V:前一步的Value张量,形状为[b, h, m, v]
new_K:加上当前步的键张量,形状为[b, h, m+1, k]
new_V:加了当前步长的Value张量,形状为[b, h, m+1, v]
维度表示如下:
M:先前执行的步骤数
B:批量大小
D:输入和输出的尺寸
H:注意力头数
k:Q,K张量的另一个维度
v: v张量的另一个维度
Multi-Query Attention(MQA)
MQA是多头注意的一种变体。
MQA的方法是保持Q的初始头数,但K和V只有一个头,这意味着所有Q个头共享相同的K和V,因此称为Multi-Query,如下图所示:
从论文的解释中可以看到,MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
MQA解码过程的代码本质上与MHA的代码相同,只是从中删除了表示头部尺寸的字母“h”。K, V, P_k, P_v的和方程:
def MQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
q = tf.einsum("bd, hdk−>bhk", x, P_q)
new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis = 2)], axis = 2)
new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis = 2)], axis = 2)
logits = tf.einsum("bhk, bmk−>bhm", q, new_K)
weights = tf.softmax(logits)
O = tf.einsum("bhm, bmv−>bhv", weights, new_V)
Y = tf.einsum("bhv, hdv−>bd", O, P_o)
return Y, new_K, new_V
上面都是tf的代码,如果阅读有问题,我从 llm-foundry项目中找到了pytorch的代码实现,这里只做个摘抄,有兴趣的请看原项目
class MultiheadAttention(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
device: str
):
"""
Multi Head init func.
Args:
d_model (int): hidden state size, e.g. 768
n_heads (int): 设定的注意力头数, e.g. 8
device (str): _description_
"""
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.Wqkv = nn.Linear( # Multi-Head Attention 的创建方法
self.d_model,
3 * self.d_model, # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
device=device
) # (d_model, 3 * d_model)
self.attn_fn = scaled_multihead_dot_product_attention
self.out_proj = nn.Linear(
self.d_model,
self.d_model,
device=device
)
def forward(
self,
x
):
"""
forward func.
Args:
x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)
Returns:
_type_: _description_
"""
qkv = self.Wqkv(x) # (1, 768, 3 * 768)
query, key, value = qkv.chunk( # 每个 tensor 都是 (1, 512, 768)
3,
dim=2
)
context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
self.n_heads
) # (1, 512, 768)
return self.out_proj(context), attn_weights, past_key_value
class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
additive bias.
"""
def __init__(
self,
d_model: int,
n_heads: int,
device: Optional[str] = None,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.Wqkv = nn.Linear( # Multi-Query Attention 的创建方法
d_model,
d_model + 2 * self