Mamba: Linear-Time Sequence Modeling with Selective State Spaces

Gu A. and Dao T. Mamba: Linear-time sequence modeling with selective state spaces. 2023.

Mamba.

Mamba

  • S4S4D 虽然解决了 SSM 计算速度的问题, 但是有一个前提, 就是 \(A, B, C, D\) 是与时间 \(t\) 无关的. 这导致这些方法只能采取一种固定的模式取处理序列问题, 作者认为这导致 SSM 无法 text 这类强上下文关系的任务.

  • 所以如上图和上述算法所示, \(B, C, \Delta\) 现在是与输入有关的了, 不同的输入会产生不同的 \(B, C, \Delta\).

  • 但是, 我们知道 \(\bar{A}\)\(A, \Delta t\) 共同决定, 这就导致 \(\bar{A}(x)\) 实际上也是与输入有关的了.

  • 而我们知道, S4, S4D 训练速度快的原因就是输出能够通过卷积的方式实现:

    \[y = \mathcal{K}_L (\bar{A}, \bar{B}, C) * u + Du, \\ \mathcal{K}_L (A, B, C) := (CB, CAB, \ldots, CA^{L-1}B). \]

    但是这个必须要求 \(A\) 是随着 \(t\) 不变的, 所以我们没法实现这一点.

  • 所以作者额外设计了 scan 算法, 如上图所示, 这是一种 hardware-aware 的算法, 他会把隐状态的更新放在 GPU 中速度最快的 SRAM 位置, 我看网上大多用下面这个图来说:

  • 我对这个不太感兴趣, 有兴趣的同学可以找相应的博客看看.

代码

[official-code]

posted @ 2024-06-12 20:31  馒头and花卷  阅读(29)  评论(0编辑  收藏  举报