MLA机制原理及代码研究

Deepseek-R1现在火出圈了,效果好、成本低,让国人用上了第一梯队的AI。DeepSeek里有很多理论和工程上的创新点,但我认为最核心的,最原创的,是在Deepseek-V2时就提出来的MLA机制(Multi-head Latent Attention,多头隐含注意力)。本文详细走一遍它的数学和代码细节。

原理

要想深刻理解一个算法,一步步推导其数学原理是不可或缺的过程。

MHA

MLA是MHA(Multi Head Attention,多头注意力)的改进版。要理解MLA,首先要理解MHA。注意力机制又是什么意思?在日常的语境里,注意就是集中大部分资源(时间、心力)处理小部分重要的/相关的信息,而只用小部分资源处理其他大部分不重要/不想关的信息。 使用文档检索做比喻,就是给定一个query,找到那些与这个query高度相关的文档K,然后集中资源处理这些文档的内容V。

假设我们现在有一个长为s的序列,其中每个token被嵌入到一个维度d的空间里,那就得到了一个s×d的矩阵H,其每行代表一个token,每列代表一个嵌入维度(可以理解为传统机器学习中的特征)。我们把这个token矩阵H转换成3个矩阵,Q,K,V,代表每个token的query、key、和value。当处理到第t个token ht时,我们用这个token的query与所有token的key做相似度计算,也就是内积qtK 然后对其归一化ϕ(qtK)。通常ϕ(·)会使用Softmax(·),这样的出来的值会被理解为概率分布,也就是对于ht,我们应该关注其他每个token的概率值Softmax(qtK)。然后用的出来的注意力概率对每个token对应的value做加权平均,相当于更具重要度把各个token的value提取出来形成一个综合的token,a=Softmax(qtK)V,最后处理这个加权平均后的综合token o=WOa.

下面的数学公式用到的符号和写法都和我们上面的额讨论以及DeepSeek论文稍有不同。论文中只考虑一个token,且用列向量表示。这是经典线性代数的写法。对一个列向量样本做线性变换就是对它左乘一个变换矩阵。而这里我们考虑一整个序列的token,且用行向量表示每个token,列向量表示特征。这是机器学习领域的传统。对一个行向量样本做线性变换就是对其右乘一个变换矩阵。

Q,K,V=HWQ,HWK,HWVQ,K=posenc(Q),posenc(K)Q,K,V=[Q1;...,;Qm],[K1;...,;Km],[V1;...;Vm]Ai=imϕ(QiKid)ViO=[A1,...,Am]WO

通常QKV的维度和H一样,都是s×h。因为我们处理自然语言的句子时,每个词都有顺序,而且这些顺序往往很重要,邻近的词通常就比遥远的词更有主意理解当下在处理的词。比如说“我爱北京天安门”这句话,要理解“安门”这两个字,就不可能不结合它前一个字“天”,只有“天安门”三个字结合在一起才是一个有意义的词。而更前面的“北京”也提供了一个上下文环境,让我们更确定“天安门”三个字是指代着现实中那一座建筑。至于更遥远的“我爱”,则基本上可以忽略不计。所以我们需要把位置信息编码金给每个token向量里Q,K=posenc(Q),posenc(K)

注意我们并不直接拿QKV来计算,而是把每个token对应的qkv分隔成m份,分别处理。这样理解起来有点别扭,更自然的理解是我们把s×h维的矩阵H变换成s×hm维的QKV,然后重复做了m次。可以理解为这是从各不同方面去注意各个token,这也是“多头注意力”重“多头”的由来。另外,Q和K内积计算相似度后还会再除以一个缩放因子d,d=hm,再去做softmax归一化。这是在实验中发现当维度d很大的时候,QK的值都太大了,会造成softmax以后只有少部分值比较显著,其他大部分都为0,引发梯度消失。所以要把QK搞小一点再做softmax。而除以d则看起来很简单,且在实验上也表现不错。

以上的公式乍看之下有点复杂,我们只保留最核心的思想,忽略缩放因子d,忽略位置编码, 忽略多头,那么MHA的公式就可以简化成下面这样

Q,K,V=HWQ,HWK,HWVA=ϕ(QK)VO=AWO

MLA

MHA加持的Transformer模型架构表现出了极其强悍的能力。但是其计算量也很大,而且占用的GPU显存也不少,尤其是KV。假设我们要做一个翻译任务,输入一段英文,让模型翻译成中文。那么在第一阶段模型会一次性把这批英文文档切词(tokenize),变换成词矩阵H,一次性算出每个token的Q,K,V。第二阶段,模型会一个一个token往外吐,也就是先一侧下一个token,把新吐出来的token加儒道输入序列,然后预测下一个token。

如果我们的英文输入有10000个token,而token向量的维度是7168(Deepseek-V3的设置),那么第一阶段计算Q,K,V就要做3个(10000,7168)×(7168,7168)的矩阵乘法。第二阶段的时候计算第t个token时秩序算出ht对应的查询向量qt,但qt却需要跟序列前面所有token的k和v做运算。显然在第二阶段我们每要预测下一个token时都去算一遍前面所有token的KV,那样就太低效率,只能把这些KV不断缓存起来,以存代算。

但在上面的例子重一对KV矩阵就有100007168个元素,通常一个Transformer模型有几十层,比如Deepseek-V3有61层,而每次也不都只处理一个序列,而是一批序列,比如8个。随着模型不断往外输出token,序列长度不断增长,而KV缓存也在不断增长。如果模型翻译英文又输出了10000个token的中文,那么总体下来KV矩阵的元素就有(8200007168612)这么多。KV缓存通常是CudaOutOfMemory的直接原因。

我们希望尽可能减少KV缓存,以提升推理效率。因为GPU核心计算的速度比从内存读取数据到GPU的速度快得多。对于相同长度的序列来说,GPU处理一批8个序列比处理两批各4个序列所需的时间要短。如果我们减少KV缓存,那么GPU一批就能处理更多输入序列,而整体的吞吐效率就能有所提升。MLA要做的事情就是减少KV缓存,进而提升模型推理效率。

MLA的核心思想是不缓存K和V,而是先把K和V一起压缩成CKV,只缓存这个压缩后的矩阵。等到需要计算的时候,再解压缩回来。Deepseek-V3的设置中,K和V的维度都是7168,而CKV只有512,足足压缩了14倍。MLA的数学表达如下:

Q=HWQCKV=HWDKVK,V=CKVWUK,CKVWUVA=ϕ(QK)VO=AWO

在MLA的计算过程中,我们不再缓存KV,而是缓存更低维的CKV。话虽如此,然而整个计算过程看下来不过是对KV压缩然后解压缩而已,还多了一步计算,多了一个权重矩阵。而且最重要的是,解压缩过程K,V=CKVWUK,CKVWUV 相当于又做了一次矩阵乘法,还产生了两个大型中间结果矩阵KV,这俩中间结果占用的显存相当于缓存了原本的KV。而这也是HuggingFace上的官方源代码所做的事情

# https://huggingface.co/deepseek-ai/DeepSeek-V3/resolve/main/modeling_deepseek.py
class DeepseekV3Attention(nn.Module):
# ......
def __init__(self, ...):
# ......
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size, # 7168
config.kv_lora_rank + config.qk_rope_head_dim, # 512 + 64 = 576
bias=config.attention_bias,
)
self.kv_b_proj = nn.Linear(
config.kv_lora_rank, # 512
# 128 * [(128+64) + 128] = 24576
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
def forward(self, ...):
# ...
compressed_kv = self.kv_a_proj_with_mqa(hidden_states) # (batch_szie, seq_len, 576)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
) # (batch_szie, seq_len, 512), (batch_szie, seq_len, 64)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) # (batch_szie, seq_len, 24576)
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
) # (batch_szie, seq_len, 128, 64+128)
# ......
# 甚至连缓存的都是解压后的KV而不是解压前的compressed_kv
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

也就是说,一番操作下来,计算量和缓存占用全都增加了。这也是为什么如果直接从huggingface拉下来跑的话,会发现它节省不了多少显存。想要得到解锁MLA节省显存的能力,必须使用矩阵吸收技巧,也就是论文里轻描淡写的那句:

during inference, since 𝑊𝑈𝐾 can be absorbed into 𝑊𝑄, and 𝑊𝑈𝑉 can be absorbed into 𝑊𝑂, we even do not need to compute keys and values out for attention.

要想知道它什么意思,我们需要在数学上把MLA展开:

O=AWO=ϕ(QK)VWO=ϕ[HWQ(CKVWUK)]CKVWUVWO=ϕ[H(WQWUK)CKV]CKV(WUVWO)

由此可见,WQWUK可以被合并,而WUVWO也可以被合并。这两个矩阵合并以后,对KV的整个计算过程都在低维空间进行,不会出现再把CKV解压缩回高维空间的情况。 况且,矩阵WQ,WUK,WUV,WO全都是模型的权重,再推理过程重是不会变的,可以看作常量。因此,如果是部署推理服务的话,再加载模型的时候就可以把这两个矩阵乘好,为以后的每次推理节省两次矩阵乘法。

现在一切看起来都很完美,但要记得上面的计算过程,为了方便我们都没有考虑到位置编码。然而位置编码是对型能影响很大的因素,我们必须考虑进去,然后就会发现一旦考虑了位置编码,那就和上面的矩阵合并不兼容了。也就是把token的位置信息编码到对应的q和k中。目前最好的位置编码方案是RoPE(Rotary Position Embedding,旋转位置编码)。其原理是对序列重的第t个token对应的q和k左乘一个旋转矩阵Rt

qt=RoPE(qt)=Rtwtkt=RoPE(kt)=Rtkt

在实际中我们当然不会为每个token ht搞出来一个矩阵Rt再做t次矩阵乘法,而是会根据RoPE的数学性质设计极其优化的算法,等价于数学生做t次矩阵乘。麻烦的点在于这些旋转矩阵Rt不像权重矩阵一样可以在部署推理前算好,而是在模型推理时才能算出来。如此一来,注意力的计算公式就变成了

O=ϕ[RoPE(HWQ)RoPE(WUKCKV)]CKV(WUVWO)

也就是WQ,WUK两个矩阵无法合并。

为了解决这个问题,DeepSeek团队各种探索,最终尝试出了一种方法:只对每个q、k向量的部分分量进行位置编码。也就是说:

Q=[Q1;Q2]=H·[W1Q,W2Q]K=[K1;K2]=CKV[W1UK,W2UK]Q2,K2=RoPE(Q2),RoPE(K2)Q,K=[Q1;Q2],[K1,K2]QK=[Q1;Q2]·[K1,K2]=Q1K1+Q2K2

如此一番操作以后

A=ϕ[QK]V=ϕ[Q1K1+Q2K2]V=ϕ[HW1Q(CKVW1UK)+RoPE(HW2Q)RoPE(CKVW2UK)]V=ϕ[H(W1QW1UK)CKV+RoPE(HW2Q)RoPE(CKVW2UK)]V

简单来说,DeepSeek团队做了一个妥协:把每个Q和K分割成两部分,一部分不做RoPE,一部分做RoPE,那么没做RoPE的部分就可以进行矩阵合并。在DeepSeek的配置文件中,前者和后者之比为512:64,也就是十分之一左右。经过这番改进,在缓存的时候除了要存CKV以外,还要缓存经过位置编码的部分key,也就是K2

实现

网上有不少具体的代码的实现。官方在Github上的代码就实现了缓存CKVK,以及WQ,WUK的矩阵合并,然而没有实现WUV,WO的合并,且在推理加载时也没有预先乘好存起来。推测是因为这是示例代码,以简洁明了为主吧(也可能时赶工,比较忙就没做)。

# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MLA(nn.Module):
# ......
def forward(self, x: torch.Tensor, ...):
kv = self.wkv_a(x) # 压缩
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) # 把k分割成需要位置编码和不做位置编码的两份
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) # 只对K的部分进行位置编码
self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) # 缓存压缩后的$C^{KV}$
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) # 缓存位置编码后的部分key $K_p$
# Q'K'^{\top} = [Q_1; Q'_2] · [K_1; K'_2]^{\top} = Q_1K_1^{\top} + Q'_2K'_2
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
# O = A · C^{KV} · W^{UV} · W^O
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))

官方的Github代码清晰简洁,而最大的缺憾是并没有采用HuggingFace的编程框架。如果想要跟HuggingFace结合使用,可以考虑清华的KTransformers,但Ktransformers加入了一些其他功能,它们的类也魔改过,不能无缝用在HuggingFace框架里,但需要做些修改。不过他们为Deepseek-V2-Chat打了个补丁,兼容了Huggingface(但并没有被合并到主干分支中去):

# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat/blob/refs%2Fpr%2F12/modeling_deepseek.py
class DeepseekV2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# ......
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv_seq_len = k_pe.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
compressed_kv = compressed_kv.unsqueeze(1)
# 只缓存压缩后的kv和位置编码后的部分k
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
compressed_kv = compressed_kv.squeeze(1)
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:]
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
q_nope = torch.matmul(q_nope, q_absorb)
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
assert attention_mask is not None
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q_pe.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.matmul(attn_output, out_absorb.mT)
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value

结语

MLA对我来说是很有启发性的工作,而且我还没有彻底理解它。最大的问题,GQA(Grouped Query Attention,分组查询注意力)也是一种KV压缩机制,也可以做到和MLA一样的压缩效率,但为什么MLA就是不会造成GQA那用的性能损失,反而还会带来一点提升?直觉上我觉得跟引入了额外的计算和矩阵变换(WDKV,WUK,WUV)有关,但我说不清其中的原理。

另外,MLA关于位置编码和矩阵吸收的思想,似乎对于MHA也是惯用的。最另外惊奇的是,MLA只对Q和K的小部分(十分之一)进行位置编码就已经足够,那么在MHA里是不是也能这么做?如果是的话,我们也可以把矩阵WQWK合并,把WOWV合并,从而减少两个矩阵乘法的计算量。

最后,除了Deepseek以外,国内其他AI团队也做出了很出彩的成果,比如Moonshot也用强化学习做出了一个强推里模型Kimi-1.5,Minimax的lightening attention也做到很快的推理速度,还将上下文长度干到了1024万token。这些都是很有价值的成就,不过Deepseek先声夺人,彻底开源,且效果极好,所以掩盖了他们的光芒。但他们不能埋没,我会在以后继续研究它们。另外,DeepSeek也在继续放出新的成果,比如最近的NAS(Native Sparse Attention,原生稀疏注意力)。吾辈当自强,上下而求索。


参考:

posted @   zrq96  阅读(311)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
点击右上角即可分享
微信分享提示