Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers

Gu A., Johnson I., Goel K., Saab K., Dao T., Rudra A., and Re C. Combining recurrent, convolutional, and continuous-time models with linear state-space layers. NeurIPS, 2021.

State space representaion-wiki.

Mamba 系列的第二作: LSSL.

符号说明

  • u(t)R, 输入信号;
  • x(t)RN, 中间状态;
  • y(t)R, 输出信号

LSSL

  • 从 LSSL 开始, 作者开始围绕 linear system 做文章:

    (1)x˙(t)=Ax(t)+Bu(t),y(t)=Cx(t)+Du(t),

    注意, 这里作者把 A,B,C,D 简化为和时间 t 无关的量, 且仅仅讨论的是一维的信号.

  • 采用 generalized bilinear transform (GBT) 可以将上述的 ODE 离散化:

    xt=A¯xt1+B¯ut,yt=Cxt+Dut,

    其中

    A¯=(IαΔtA)1(I+(1α)ΔtA),B¯=Δt(IαΔtA)1B,

    Δt 是时间间隔, 而 α 是一个 bilinear 的超参数. 具体的推导可以见 here.

  • 我们现在仅关注 xt, 然后看看一个具体的例子. 取 A=1,B=1,α=1,Δt=exp(z), 我们有

    xt=11+exp(z)xt1+exp(z)1+exp(z)ut=(1σ(z))xt1+σ(z)ut.

    这实际上就是一个 gating 机制 (常常用在 RNN 的更新上).

  • ok, 对于 RNN, 我们可以把其中的一层看出是对 (1) 的一次近似, 那么多层的效果是什么? 作者认为这和 Picard iteration 有关系, Picard iteration, 即

    xi+1(t):=xi(t0)+t0tf(s,xi(s))ds

    可以证明随着 i 的增加, 会逐步收敛到真实解, 换言之, 多层的叠加可以让误差越来越小.

  • Deep LSSLs:

    • LSSLs 的具体构造就是上述离散过程的叠加, 同时不同 block 之间添加 skip connection 和 layer norm.
    • 假设我们的输入信号是 RL×H 的, 其中 L 表示序列长度, H 是维度, 此时信号不是 1 维的. 作者的做法是, 为每个维度单独设立:

    ARN×N,BRN×1,CR1×N,DR,ΔtR

    分别进行上述的离散过程. 此外, 这里需要注意的是, Δt 我们也是可学习的.

    • 作者还提到, 输出信号 y(t) 不一定必须是 1 维的, 也可以是 M 维的, 此时 CRM×N,DRM×1. 这会导致最后的输出维度是 HM, 可以通过 MLP 映射回 H.
      所以总共的参数量为:

    HNN+HN1+HMN+HM+H+HMH=O(HN2+HMN+H2M).

注: 作者好像把 LSSL 的代码删掉了, 不过我注意到后续的 S4 的设定里面, 是为每个维度单独设立 A,B,C,D 还是共享是可以选择的 (也可以是部分维度共享, 取决于 n_ssm 这个参数).

注: A 的初始化应用 HiPPO.

和其它方法的联系

  • LSSL 除了可以看成是 RNN 外, 实际上还具有卷积的特性, 容易发现:

    yk=C(A¯)kB¯u0+C(A¯)k1B¯u1++CAB¯uk1+B¯uk+Duk=sC(A¯)ksus,

    y=KL(A¯,B¯,C)u+Du.

    其中

    KL(A,B,C):=(CB,CAB,,CAL1B).

  • 所以, LSSL 具有卷积的优点, 可以并行计算.

代码

[official-code]

posted @   馒头and花卷  阅读(48)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
历史上的今天:
2023-06-11 Graph Neural Networks Inspired by Classical Iterative Algorithms
2019-06-11 Proximal Algorithms 5 Parallel and Distributed Algorithms
点击右上角即可分享
微信分享提示