On the Parameterization and Initialization of Diagonal State Space Models

Gu A., Gupta A., Goel K. and Re C. On the parameterization and initialization of diagonal state space models. NeurIPS, 2022.

Mamba 系列第四作: S4D.

符号说明

  • \(u(t) \in \mathbb{R}\), 输入信号;
  • \(x(t) \in \mathbb{R}^N\), 中间状态;
  • \(y(t) \in \mathbb{R}\), 输出信号

S4D

  • LSSL 中我们已经阐述了线性系统:

    \[x'(t) = A x(t) + Bu(t), \\ y(t) = C x(t) + Du(t) \]

    在兼顾 RNN, CNN 的优势的可能性, 并且离散化后说明 LSSL 实际上可以改写成卷积的形式, 从而实现高效的并行化:

    \[y = \mathcal{K}_L (\bar{A}, \bar{B}, C) * u + Du, \\ \mathcal{K}_L (A, B, C) := (CB, CAB, \ldots, CA^{L-1}B). \]

  • S4 的初衷是对角化 \(A\) 来避免卷积过程中 \(A^l\) 的复杂运算, 不过考虑到完全对角化的一个数值问题, 最终 S4 给出的策略是重参数化 \(A\) 为对角矩阵 + 低秩矩阵.

  • 不过最近 DSS 发现通过合理的初始化, 就能够避免数值问题, 本文在此基础上进一步探索.

  • 首先作者考虑简化的 ODE:

    \[x'(t) = A x(t) + Bu(t), \\ y(t) = C x(t), \]

    并给出一个等价的卷积形式:

    \[K(t) = C e^{tA} B, \\ y(t) = (K * u)(t). \]


proof:

  • 首先注意到:

    \[\begin{array}{ll} & x'(t) = Ax(t) + Bu(t) \\ \Rightarrow & x'(t) - Ax(t) = Bu(t) \\ \Rightarrow & e^{-tA} x'(t) - e^{-tA} Ax(t) = e^{-tA} Bu(t) \\ \Rightarrow & (e^{-tA} x(t))' = e^{-tA} Bu(t) \\ \Rightarrow & e^{-tA} x(t) = e^{-tA}x(0) + \int_0^t e^{-\tau A} Bu(\tau) d \tau \\ \Rightarrow & x(t) = x(0) + \int_0^t e^{(t-\tau) A} Bu(\tau) d \tau \\ \Rightarrow & y(t) = \int_0^t C e^{(t-\tau) A} Bu(\tau) d \tau = (K * u)(t) \quad \leftarrow x(0) = 0 \\ \end{array} \]


  • 接下来我们假设 \(A \in \mathbb{C}^{N \times N}\) 为一个对角矩阵, 考虑到 \(B \in \mathbb{C}^{N \times 1}, C \in \mathbb{C}^{1 \times N}\), 我们可以令 \(A_n, B_n, C_n\) 对应的第 \(n\) 个元素. 由此一来, 我们就会有

    \[K(t) = \sum_{n=0}^{N-1} C_n K_n(t), \quad K_n(t) := \bm{e}_n^T e^{t A} B, \]

    其中 \(\bm{e}_n \in \{0, 1\}^N\) 表示第 \(n\) 个元素为 1 其余为 0 的向量.

  • 离散化后, 我们有:

    \[y = u * \bar{K}, \quad \bar{K} = (C\bar{B}, C\overline{AB}, \ldots, C\bar{A}^{L-1} \bar{B}) \in \mathbb{C}^L. \]

  • 容易证明:

    \[\bar{K} = [\bar{B}_0 C_0, \ldots, \bar{B}_{N-1} C_{N-1}] \left [ \begin{array}{ccccc} 1 & \bar{A}_0 & \bar{A}_0^2 & \ldots & \bar{A}_0^{L-1} \\ 1 & \bar{A}_1 & \bar{A}_1^2 & \ldots & \bar{A}_1^{L-1} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \bar{A}_{N-1} & \bar{A}_{N-1}^2 & \ldots & \bar{A}_{N-1}^{L-1} \\ \end{array} \right ], \]

    这是 Vandermonde matrix-vector multiplication.

  • 正常算, \(\bar{K}\) 需要 \(O(NL)\) 的计算量, 不过 Vandermonde matrix-vector multiplication 实际上有更快的算法, 可以达到 \(O(N + L)\) 的复杂度.

  • 最后, 作者讨论了初始化, \(A\) 可以用 HiPPO 矩阵的 DPLR 后的对角矩阵初始化, 或者用直接用对角线, 以及额外还有两种:

    \[\text{S4D-Inv}: \quad A_n = -\frac{1}{2} + i \frac{N}{\pi} (\frac{N}{2n + 1} - 1), \\ \text{S4D-Lin}: \quad A_n = -\frac{1}{2} + i\pi n. \]

  • 作者给了算法:

代码

[official-code]

posted @ 2024-06-12 14:40  馒头and花卷  阅读(37)  评论(0编辑  收藏  举报