Fast Transformer Decoding: One Write-Head is All You Need论文阅读笔记(MQA)
Motivation & Abs
增量推理对于MHA是非常慢的(难以并行),因为重复加载大的键/值会增大内存带宽的开销。为此作者提出了multi-query attention(MQA),其中不同注意力头共享相同的键和值,减小了增量解码的内存带宽要求。MQA可以大幅提升解码的速度,同时推理质量仅有略微下降。
Method
Multihead Attention (Incremental)
这里的计算次数是\(\Theta(bnd^2)\),因为作者进行了如下简化假设:
\(m=n\)
\(k=v=\frac{d}{h}\)
\(n\leq d\)
带入可得使用einsum计算q的运算次数为\(bdhk=bdh\frac{d}{h}=bd^2\),同理new_K和new_V也是同样的运算次数。计算logits的运算次数为\(bhmk=bhn\frac{d}{h}=bnd\leq bd^2\),计算o的次数为\(bhmv=bhn\frac{d}{h}=bnd\leq bd^2\),计算y的次数为\(bhvd=bh\frac{d}{h}d=bd^2\)。因此进行n次函数调用的运算次数为\(\Theta(bnd^2)\)。
进行n次函数调用的内存访问为\(\Theta(bn^2d+nd^2)\)。第一项来自K和V,第二项来自P矩阵。
通过以上观察可以发现,内存访问次数与算数运算次数的比率为\(\Theta(\frac{n}{d}+\frac{1}{b})\),当\(n\approx d\)或者\(d\approx 1\),瓶颈就变成了内存带宽。为了让增量推理更加高效,必须让这个比例远小于1。减小第二项只需要增加batch size,但渐小第一项比较困难。本文提出了一种方法,删除K和V的heads维度,同时在Q中保留这一维度。
Multi-Query Attention
MQA可以视作MHA的变体,不同的 head 共享一组键和值:
这样可以将KV cache的大小变为\(\frac{1}{h}\),非常可观。