MLA机制原理及代码研究
Deepseek-R1现在火出圈了,效果好、成本低,让国人用上了第一梯队的AI。DeepSeek里有很多理论和工程上的创新点,但我认为最核心的,最原创的,是在Deepseek-V2时就提出来的MLA机制(Multi-head Latent Attention,多头隐含注意力)。本文详细走一遍它的数学和代码细节。
原理
要想深刻理解一个算法,一步步推导其数学原理是不可或缺的过程。
MHA
MLA是MHA(Multi Head Attention,多头注意力)的改进版。要理解MLA,首先要理解MHA。注意力机制又是什么意思?在日常的语境里,注意就是集中大部分资源(时间、心力)处理小部分重要的/相关的信息,而只用小部分资源处理其他大部分不重要/不想关的信息。 使用文档检索做比喻,就是给定一个query,找到那些与这个query高度相关的文档K,然后集中资源处理这些文档的内容V。
假设我们现在有一个长为的序列,其中每个token被嵌入到一个维度的空间里,那就得到了一个的矩阵,其每行代表一个token,每列代表一个嵌入维度(可以理解为传统机器学习中的特征)。我们把这个token矩阵转换成3个矩阵,,代表每个token的query、key、和value。当处理到第个token 时,我们用这个token的query与所有token的key做相似度计算,也就是内积 然后对其归一化。通常会使用,这样的出来的值会被理解为概率分布,也就是对于,我们应该关注其他每个token的概率值。然后用的出来的注意力概率对每个token对应的value做加权平均,相当于更具重要度把各个token的value提取出来形成一个综合的token,,最后处理这个加权平均后的综合token .
下面的数学公式用到的符号和写法都和我们上面的额讨论以及DeepSeek论文稍有不同。论文中只考虑一个token,且用列向量表示。这是经典线性代数的写法。对一个列向量样本做线性变换就是对它左乘一个变换矩阵。而这里我们考虑一整个序列的token,且用行向量表示每个token,列向量表示特征。这是机器学习领域的传统。对一个行向量样本做线性变换就是对其右乘一个变换矩阵。
通常QKV的维度和H一样,都是。因为我们处理自然语言的句子时,每个词都有顺序,而且这些顺序往往很重要,邻近的词通常就比遥远的词更有主意理解当下在处理的词。比如说“我爱北京天安门”这句话,要理解“安门”这两个字,就不可能不结合它前一个字“天”,只有“天安门”三个字结合在一起才是一个有意义的词。而更前面的“北京”也提供了一个上下文环境,让我们更确定“天安门”三个字是指代着现实中那一座建筑。至于更遥远的“我爱”,则基本上可以忽略不计。所以我们需要把位置信息编码金给每个token向量里。
注意我们并不直接拿QKV来计算,而是把每个token对应的qkv分隔成份,分别处理。这样理解起来有点别扭,更自然的理解是我们把维的矩阵变换成维的QKV,然后重复做了次。可以理解为这是从各不同方面去注意各个token,这也是“多头注意力”重“多头”的由来。另外,Q和K内积计算相似度后还会再除以一个缩放因子,再去做softmax归一化。这是在实验中发现当维度很大的时候,的值都太大了,会造成softmax以后只有少部分值比较显著,其他大部分都为0,引发梯度消失。所以要把搞小一点再做softmax。而除以则看起来很简单,且在实验上也表现不错。
以上的公式乍看之下有点复杂,我们只保留最核心的思想,忽略缩放因子,忽略位置编码, 忽略多头,那么MHA的公式就可以简化成下面这样
MLA
MHA加持的Transformer模型架构表现出了极其强悍的能力。但是其计算量也很大,而且占用的GPU显存也不少,尤其是KV。假设我们要做一个翻译任务,输入一段英文,让模型翻译成中文。那么在第一阶段模型会一次性把这批英文文档切词(tokenize),变换成词矩阵,一次性算出每个token的。第二阶段,模型会一个一个token往外吐,也就是先一侧下一个token,把新吐出来的token加儒道输入序列,然后预测下一个token。
如果我们的英文输入有10000个token,而token向量的维度是7168(Deepseek-V3的设置),那么第一阶段计算就要做3个的矩阵乘法。第二阶段的时候计算第个token时秩序算出对应的查询向量,但却需要跟序列前面所有token的k和v做运算。显然在第二阶段我们每要预测下一个token时都去算一遍前面所有token的KV,那样就太低效率,只能把这些KV不断缓存起来,以存代算。
但在上面的例子重一对KV矩阵就有个元素,通常一个Transformer模型有几十层,比如Deepseek-V3有61层,而每次也不都只处理一个序列,而是一批序列,比如8个。随着模型不断往外输出token,序列长度不断增长,而KV缓存也在不断增长。如果模型翻译英文又输出了10000个token的中文,那么总体下来KV矩阵的元素就有这么多。KV缓存通常是CudaOutOfMemory
的直接原因。
我们希望尽可能减少KV缓存,以提升推理效率。因为GPU核心计算的速度比从内存读取数据到GPU的速度快得多。对于相同长度的序列来说,GPU处理一批8个序列比处理两批各4个序列所需的时间要短。如果我们减少KV缓存,那么GPU一批就能处理更多输入序列,而整体的吞吐效率就能有所提升。MLA要做的事情就是减少KV缓存,进而提升模型推理效率。
MLA的核心思想是不缓存K和V,而是先把K和V一起压缩成,只缓存这个压缩后的矩阵。等到需要计算的时候,再解压缩回来。Deepseek-V3的设置中,K和V的维度都是7168,而只有512,足足压缩了14倍。MLA的数学表达如下:
在MLA的计算过程中,我们不再缓存KV,而是缓存更低维的。话虽如此,然而整个计算过程看下来不过是对KV压缩然后解压缩而已,还多了一步计算,多了一个权重矩阵。而且最重要的是,解压缩过程 相当于又做了一次矩阵乘法,还产生了两个大型中间结果矩阵,这俩中间结果占用的显存相当于缓存了原本的KV。而这也是HuggingFace上的官方源代码所做的事情
# https://huggingface.co/deepseek-ai/DeepSeek-V3/resolve/main/modeling_deepseek.py class DeepseekV3Attention(nn.Module): # ...... def __init__(self, ...): # ...... self.kv_a_proj_with_mqa = nn.Linear( self.hidden_size, # 7168 config.kv_lora_rank + config.qk_rope_head_dim, # 512 + 64 = 576 bias=config.attention_bias, ) self.kv_b_proj = nn.Linear( config.kv_lora_rank, # 512 # 128 * [(128+64) + 128] = 24576 self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) def forward(self, ...): # ... compressed_kv = self.kv_a_proj_with_mqa(hidden_states) # (batch_szie, seq_len, 576) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) # (batch_szie, seq_len, 512), (batch_szie, seq_len, 64) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) # (batch_szie, seq_len, 24576) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) # (batch_szie, seq_len, 128, 64+128) # ...... # 甚至连缓存的都是解压后的KV而不是解压前的compressed_kv if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs )
也就是说,一番操作下来,计算量和缓存占用全都增加了。这也是为什么如果直接从huggingface拉下来跑的话,会发现它节省不了多少显存。想要得到解锁MLA节省显存的能力,必须使用矩阵吸收技巧,也就是论文里轻描淡写的那句:
during inference, since can be absorbed into , and can be absorbed into , we even do not need to compute keys and values out for attention.
要想知道它什么意思,我们需要在数学上把MLA展开:
由此可见,和可以被合并,而和也可以被合并。这两个矩阵合并以后,对KV的整个计算过程都在低维空间进行,不会出现再把解压缩回高维空间的情况。 况且,矩阵全都是模型的权重,再推理过程重是不会变的,可以看作常量。因此,如果是部署推理服务的话,再加载模型的时候就可以把这两个矩阵乘好,为以后的每次推理节省两次矩阵乘法。
现在一切看起来都很完美,但要记得上面的计算过程,为了方便我们都没有考虑到位置编码。然而位置编码是对型能影响很大的因素,我们必须考虑进去,然后就会发现一旦考虑了位置编码,那就和上面的矩阵合并不兼容了。也就是把token的位置信息编码到对应的q和k中。目前最好的位置编码方案是RoPE(Rotary Position Embedding,旋转位置编码)。其原理是对序列重的第t个token对应的q和k左乘一个旋转矩阵
在实际中我们当然不会为每个token 搞出来一个矩阵再做t次矩阵乘法,而是会根据RoPE的数学性质设计极其优化的算法,等价于数学生做t次矩阵乘。麻烦的点在于这些旋转矩阵不像权重矩阵一样可以在部署推理前算好,而是在模型推理时才能算出来。如此一来,注意力的计算公式就变成了
也就是两个矩阵无法合并。
为了解决这个问题,DeepSeek团队各种探索,最终尝试出了一种方法:只对每个q、k向量的部分分量进行位置编码。也就是说:
如此一番操作以后
简单来说,DeepSeek团队做了一个妥协:把每个Q和K分割成两部分,一部分不做RoPE,一部分做RoPE,那么没做RoPE的部分就可以进行矩阵合并。在DeepSeek的配置文件中,前者和后者之比为,也就是十分之一左右。经过这番改进,在缓存的时候除了要存以外,还要缓存经过位置编码的部分key,也就是。
实现
网上有不少具体的代码的实现。官方在Github上的代码就实现了缓存和,以及的矩阵合并,然而没有实现的合并,且在推理加载时也没有预先乘好存起来。推测是因为这是示例代码,以简洁明了为主吧(也可能时赶工,比较忙就没做)。
# https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py class MLA(nn.Module): # ...... def forward(self, x: torch.Tensor, ...): kv = self.wkv_a(x) # 压缩 kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) # 把k分割成需要位置编码和不做位置编码的两份 k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) # 只对K的部分进行位置编码 self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) # 缓存压缩后的$C^{KV}$ self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) # 缓存位置编码后的部分key $K_p$ # Q'K'^{\top} = [Q_1; Q'_2] · [K_1; K'_2]^{\top} = Q_1K_1^{\top} + Q'_2K'_2 scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) # O = A · C^{KV} · W^{UV} · W^O x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) x = self.wo(x.flatten(2))
官方的Github代码清晰简洁,而最大的缺憾是并没有采用HuggingFace的编程框架。如果想要跟HuggingFace结合使用,可以考虑清华的KTransformers,但Ktransformers加入了一些其他功能,它们的类也魔改过,不能无缝用在HuggingFace框架里,但需要做些修改。不过他们为Deepseek-V2-Chat打了个补丁,兼容了Huggingface(但并没有被合并到主干分支中去):
# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat/blob/refs%2Fpr%2F12/modeling_deepseek.py class DeepseekV2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" # ...... def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv_seq_len = k_pe.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models compressed_kv = compressed_kv.unsqueeze(1) # 只缓存压缩后的kv和位置编码后的部分k k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) compressed_kv = compressed_kv.squeeze(1) kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:] out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :] q_nope = torch.matmul(q_nope, q_absorb) attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) assert attention_mask is not None if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(q_pe.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) attn_output = torch.matmul(attn_output, out_absorb.mT) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value
结语
MLA对我来说是很有启发性的工作,而且我还没有彻底理解它。最大的问题,GQA(Grouped Query Attention,分组查询注意力)也是一种KV压缩机制,也可以做到和MLA一样的压缩效率,但为什么MLA就是不会造成GQA那用的性能损失,反而还会带来一点提升?直觉上我觉得跟引入了额外的计算和矩阵变换()有关,但我说不清其中的原理。
另外,MLA关于位置编码和矩阵吸收的思想,似乎对于MHA也是惯用的。最另外惊奇的是,MLA只对Q和K的小部分(十分之一)进行位置编码就已经足够,那么在MHA里是不是也能这么做?如果是的话,我们也可以把矩阵和合并,把和合并,从而减少两个矩阵乘法的计算量。
最后,除了Deepseek以外,国内其他AI团队也做出了很出彩的成果,比如Moonshot也用强化学习做出了一个强推里模型Kimi-1.5,Minimax的lightening attention也做到很快的推理速度,还将上下文长度干到了1024万token。这些都是很有价值的成就,不过Deepseek先声夺人,彻底开源,且效果极好,所以掩盖了他们的光芒。但他们不能埋没,我会在以后继续研究它们。另外,DeepSeek也在继续放出新的成果,比如最近的NAS(Native Sparse Attention,原生稀疏注意力)。吾辈当自强,上下而求索。
参考:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY