Efficiently Modeling Long Sequences with Structured State Spaces

Gu A., Goel K. and Re C. Efficiently modeling long sequences with structured state spaces. NeurIPS, 2022.

Mamba 系列第三作.

符号说明

  • \(u(t) \in \mathbb{R}\), 输入信号;
  • \(x(t) \in \mathbb{R}^N\), 中间状态;
  • \(y(t) \in \mathbb{R}\), 输出信号

S4

  • LSSL 中我们已经阐述了线性系统:

    \[x'(t) = A x(t) + Bu(t), \\ y(t) = C x(t) + D u(t) \]

    在兼顾 RNN, CNN 的优势的可能性, 并且离散化后说明 LSSL 实际上可以改写成卷积的形式, 从而实现高效的并行化:

    \[y = \mathcal{K}_L (\bar{A}, \bar{B}, C) * u + Du, \\ \mathcal{K}_L (A, B, C) := (CB, CAB, \ldots, CA^{L-1}B). \]

  • 现在的问题是, 如果 \(A\) 是固定的, 那么我们实际上只需要计算一次 \(\mathcal{K}_L\) 即可, 但是如果 \(A\) 不是固定的, 那么我们每次就需要付出额外的(相当多的)代价去计算 \(\mathcal{K}_L\), 其主要代价在于 \(A\).

  • 假设我们能够通过某个 \(V \in \mathbb{R}^{N \times N}\) 对角化 \(A\), 则我们有:

    \[\tilde{x}' = V^{-1} A V \tilde{x} + V^{-1} B u, \\ y = CV \tilde{x}. \]

    于是 \((V^{-1}AV)^{l}\) 计算起来就会比较方便了.

  • 但是问题是, 作者发现 HiPPO 矩阵的 \(V\) 的值的大小规模可以达到 \(2^{4N/3}\), 所以计算的时候会造成严重的数值问题.

  • S4 提出了一种改进方案:

    \[A = V(\Lambda - (V^*P) (V^*Q^*))V^*, \]

    其中

    \[P, Q \in \mathbb{R}^{N \times R}, \]

    为低秩矩阵.
    实际上可以证明, 对于所有的 HiPPO matrix, 都可以进行这样的分解.

  • 既然如此, S4 选择重参数化 \(A\)\((\Lambda \in \mathbb{R}^{N \times 1}, P \in \mathbb{R}^{N \times 1}, Q \in \mathbb{R}^{N \times 1})\), 以及 \(B, C \in \mathbb{R}^{N \times 1}\), 为 5N 的可训练参数.

注: 我看代码的时候, 感觉发现 \(V\) 是没有保留的, 所以直接就是采用 \(V\) 变换后的那个方程了 (我一开始以为会用 HiPPO matrix 的初始的 \(V\) 最后做个转换的, 实际上没有).

注: 作者没有提及 \(\Delta t\) 是否是训练的, 我感觉应该和 LSSL 一样可训练吧.

注: \(R=1\) 不是必须的, 代码里设置了参数可以调节.

代码

[official-code]

posted @ 2024-06-12 10:26  馒头and花卷  阅读(15)  评论(0编辑  收藏  举报