On the Parameterization and Initialization of Diagonal State Space Models
概
Mamba 系列第四作: S4D.
符号说明
- \(u(t) \in \mathbb{R}\), 输入信号;
- \(x(t) \in \mathbb{R}^N\), 中间状态;
- \(y(t) \in \mathbb{R}\), 输出信号
S4D
-
在 LSSL 中我们已经阐述了线性系统:
\[x'(t) = A x(t) + Bu(t), \\ y(t) = C x(t) + Du(t) \]在兼顾 RNN, CNN 的优势的可能性, 并且离散化后说明 LSSL 实际上可以改写成卷积的形式, 从而实现高效的并行化:
\[y = \mathcal{K}_L (\bar{A}, \bar{B}, C) * u + Du, \\ \mathcal{K}_L (A, B, C) := (CB, CAB, \ldots, CA^{L-1}B). \] -
S4 的初衷是对角化 \(A\) 来避免卷积过程中 \(A^l\) 的复杂运算, 不过考虑到完全对角化的一个数值问题, 最终 S4 给出的策略是重参数化 \(A\) 为对角矩阵 + 低秩矩阵.
-
不过最近 DSS 发现通过合理的初始化, 就能够避免数值问题, 本文在此基础上进一步探索.
-
首先作者考虑简化的 ODE:
\[x'(t) = A x(t) + Bu(t), \\ y(t) = C x(t), \]并给出一个等价的卷积形式:
\[K(t) = C e^{tA} B, \\ y(t) = (K * u)(t). \]
proof:
- 首先注意到:\[\begin{array}{ll} & x'(t) = Ax(t) + Bu(t) \\ \Rightarrow & x'(t) - Ax(t) = Bu(t) \\ \Rightarrow & e^{-tA} x'(t) - e^{-tA} Ax(t) = e^{-tA} Bu(t) \\ \Rightarrow & (e^{-tA} x(t))' = e^{-tA} Bu(t) \\ \Rightarrow & e^{-tA} x(t) = e^{-tA}x(0) + \int_0^t e^{-\tau A} Bu(\tau) d \tau \\ \Rightarrow & x(t) = x(0) + \int_0^t e^{(t-\tau) A} Bu(\tau) d \tau \\ \Rightarrow & y(t) = \int_0^t C e^{(t-\tau) A} Bu(\tau) d \tau = (K * u)(t) \quad \leftarrow x(0) = 0 \\ \end{array} \]
-
接下来我们假设 \(A \in \mathbb{C}^{N \times N}\) 为一个对角矩阵, 考虑到 \(B \in \mathbb{C}^{N \times 1}, C \in \mathbb{C}^{1 \times N}\), 我们可以令 \(A_n, B_n, C_n\) 对应的第 \(n\) 个元素. 由此一来, 我们就会有
\[K(t) = \sum_{n=0}^{N-1} C_n K_n(t), \quad K_n(t) := \bm{e}_n^T e^{t A} B, \]其中 \(\bm{e}_n \in \{0, 1\}^N\) 表示第 \(n\) 个元素为 1 其余为 0 的向量.
-
离散化后, 我们有:
\[y = u * \bar{K}, \quad \bar{K} = (C\bar{B}, C\overline{AB}, \ldots, C\bar{A}^{L-1} \bar{B}) \in \mathbb{C}^L. \] -
容易证明:
\[\bar{K} = [\bar{B}_0 C_0, \ldots, \bar{B}_{N-1} C_{N-1}] \left [ \begin{array}{ccccc} 1 & \bar{A}_0 & \bar{A}_0^2 & \ldots & \bar{A}_0^{L-1} \\ 1 & \bar{A}_1 & \bar{A}_1^2 & \ldots & \bar{A}_1^{L-1} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \bar{A}_{N-1} & \bar{A}_{N-1}^2 & \ldots & \bar{A}_{N-1}^{L-1} \\ \end{array} \right ], \]这是 Vandermonde matrix-vector multiplication.
-
正常算, \(\bar{K}\) 需要 \(O(NL)\) 的计算量, 不过 Vandermonde matrix-vector multiplication 实际上有更快的算法, 可以达到 \(O(N + L)\) 的复杂度.
-
最后, 作者讨论了初始化, \(A\) 可以用 HiPPO 矩阵的 DPLR 后的对角矩阵初始化, 或者用直接用对角线, 以及额外还有两种:
\[\text{S4D-Inv}: \quad A_n = -\frac{1}{2} + i \frac{N}{\pi} (\frac{N}{2n + 1} - 1), \\ \text{S4D-Lin}: \quad A_n = -\frac{1}{2} + i\pi n. \] -
作者给了算法: