Fast Transformer Decoding: One Write-Head is All You Need论文阅读笔记(MQA)

Motivation & Abs

增量推理对于MHA是非常慢的(难以并行),因为重复加载大的键/值会增大内存带宽的开销。为此作者提出了multi-query attention(MQA),其中不同注意力头共享相同的键和值,减小了增量解码的内存带宽要求。MQA可以大幅提升解码的速度,同时推理质量仅有略微下降。

Method

Multihead Attention (Incremental)

截屏2024-12-14 15.31.52

这里的计算次数是Θ(bnd2),因为作者进行了如下简化假设:

m=n

k=v=dh

nd

带入可得使用einsum计算q的运算次数为bdhk=bdhdh=bd2,同理new_K和new_V也是同样的运算次数。计算logits的运算次数为bhmk=bhndh=bndbd2,计算o的次数为bhmv=bhndh=bndbd2,计算y的次数为bhvd=bhdhd=bd2。因此进行n次函数调用的运算次数为Θ(bnd2)

进行n次函数调用的内存访问为Θ(bn2d+nd2)。第一项来自K和V,第二项来自P矩阵。

通过以上观察可以发现,内存访问次数与算数运算次数的比率为Θ(nd+1b),当nd或者d1,瓶颈就变成了内存带宽。为了让增量推理更加高效,必须让这个比例远小于1。减小第二项只需要增加batch size,但渐小第一项比较困难。本文提出了一种方法,删除K和V的heads维度,同时在Q中保留这一维度。

Multi-Query Attention

MQA可以视作MHA的变体,不同的 head 共享一组键和值:

截屏2024-12-14 16.22.54

截屏2024-12-14 16.23.38

这样可以将KV cache的大小变为1h,非常可观。

posted @   脂环  阅读(71)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
历史上的今天:
2020-12-14 第 45 届国际大学生程序设计竞赛(ICPC)亚洲区域赛(上海)D. Walker(二分/分类讨论)
2020-12-14 第 45 届国际大学生程序设计竞赛(ICPC)亚洲区域赛(上海)M . Gitignore(模拟)
2020-12-14 第 45 届国际大学生程序设计竞赛(ICPC)亚洲区域赛(上海)I.Sky Garden(几何/思维)
点击右上角即可分享
微信分享提示
主题色彩