Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers

Gu A., Johnson I., Goel K., Saab K., Dao T., Rudra A., and Re C. Combining recurrent, convolutional, and continuous-time models with linear state-space layers. NeurIPS, 2021.

State space representaion-wiki.

Mamba 系列的第二作: LSSL.

符号说明

  • \(u(t) \in \mathbb{R}\), 输入信号;
  • \(x(t) \in \mathbb{R}^N\), 中间状态;
  • \(y(t) \in \mathbb{R}\), 输出信号

LSSL

  • 从 LSSL 开始, 作者开始围绕 linear system 做文章:

    \[\tag{1} \dot{x}(t) = A x(t) + B u(t), \\ y(t) = C x(t) + Du(t), \]

    注意, 这里作者把 \(A, B, C, D\) 简化为和时间 \(t\) 无关的量, 且仅仅讨论的是一维的信号.

  • 采用 generalized bilinear transform (GBT) 可以将上述的 ODE 离散化:

    \[x_t = \bar{A} x_{t-1} + \bar{B} u_t, \\ y_t = C x_t + D u_t, \]

    其中

    \[\bar{A} = (I - \alpha \Delta t \cdot A)^{-1} (I + (1 - \alpha) \Delta t \cdot A), \\ \bar{B} = \Delta t (I - \alpha \Delta t \cdot A)^{-1} B, \]

    \(\Delta t\) 是时间间隔, 而 \(\alpha\) 是一个 bilinear 的超参数. 具体的推导可以见 here.

  • 我们现在仅关注 \(x_t\), 然后看看一个具体的例子. 取 \(A=-1, B=1, \alpha=1, \Delta t = \exp(z)\), 我们有

    \[x_t = \frac{1}{1 + \exp(z)} x_{t-1} + \frac{\exp(z)}{1 + \exp(z)} u_t = (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t. \]

    这实际上就是一个 gating 机制 (常常用在 RNN 的更新上).

  • ok, 对于 RNN, 我们可以把其中的一层看出是对 (1) 的一次近似, 那么多层的效果是什么? 作者认为这和 Picard iteration 有关系, Picard iteration, 即

    \[x_{i+1}(t) := x_i (t_0) + \int_{t_0}^t f(s, x_i(s)) ds \]

    可以证明随着 \(i\) 的增加, 会逐步收敛到真实解, 换言之, 多层的叠加可以让误差越来越小.

  • Deep LSSLs:

    • LSSLs 的具体构造就是上述离散过程的叠加, 同时不同 block 之间添加 skip connection 和 layer norm.
    • 假设我们的输入信号是 \(\mathbb{R}^{L \times H}\) 的, 其中 \(L\) 表示序列长度, \(H\) 是维度, 此时信号不是 1 维的. 作者的做法是, 为每个维度单独设立:

    \[A \in \mathbb{R}^{N \times N}, \quad B \in \mathbb{R}^{N \times 1}, \quad C \in \mathbb{R}^{1 \times N}, \quad D \in \mathbb{R}, \quad \Delta t \in \mathbb{R} \]

    分别进行上述的离散过程. 此外, 这里需要注意的是, \(\Delta t\) 我们也是可学习的.

    • 作者还提到, 输出信号 \(y(t)\) 不一定必须是 1 维的, 也可以是 \(M\) 维的, 此时 \(C \in \mathbb{R}^{M \times N}, D \in \mathbb{R}^{M \times 1}\). 这会导致最后的输出维度是 \(H\cdot M\), 可以通过 MLP 映射回 \(H\).
      所以总共的参数量为:

    \[HNN + HN1 + HMN + HM + H + HMH = \mathcal{O}(HN^2 + HMN + H^2M). \]

注: 作者好像把 LSSL 的代码删掉了, 不过我注意到后续的 S4 的设定里面, 是为每个维度单独设立 \(A, B, C, D\) 还是共享是可以选择的 (也可以是部分维度共享, 取决于 n_ssm 这个参数).

注: \(A\) 的初始化应用 HiPPO.

和其它方法的联系

  • LSSL 除了可以看成是 RNN 外, 实际上还具有卷积的特性, 容易发现:

    \[\begin{array}{ll} y_k &= C(\bar{A})^k \bar{B} u_0 + C (\bar{A})^{k-1} \bar{B} u_1 + \cdots + C \overline{AB} u_{k-1} + \bar{B} u_k + D u_k \\ &= \sum_{s} C(\bar{A})^{k-s} u_s \\ \end{array}, \]

    \[y = \mathcal{K}_L (\bar{A}, \bar{B}, C) * u + Du. \]

    其中

    \[\mathcal{K}_L (A, B, C) := (CB, CAB, \ldots, CA^{L-1}B). \]

  • 所以, LSSL 具有卷积的优点, 可以并行计算.

代码

[official-code]

posted @ 2024-06-11 20:00  馒头and花卷  阅读(43)  评论(0编辑  收藏  举报