Gu A., Johnson I., Goel K., Saab K., Dao T., Rudra A., and Re C. Combining recurrent, convolutional, and continuous-time models with linear state-space layers. NeurIPS, 2021.
State space representaion-wiki.
概
Mamba 系列的第二作: LSSL.
符号说明
- u(t)∈R, 输入信号;
- x(t)∈RN, 中间状态;
- y(t)∈R, 输出信号
LSSL

-
从 LSSL 开始, 作者开始围绕 linear system 做文章:
˙x(t)=Ax(t)+Bu(t),y(t)=Cx(t)+Du(t),(1)
注意, 这里作者把 A,B,C,D 简化为和时间 t 无关的量, 且仅仅讨论的是一维的信号.
-
采用 generalized bilinear transform (GBT) 可以将上述的 ODE 离散化:
xt=¯Axt−1+¯But,yt=Cxt+Dut,
其中
¯A=(I−αΔt⋅A)−1(I+(1−α)Δt⋅A),¯B=Δt(I−αΔt⋅A)−1B,
Δt 是时间间隔, 而 α 是一个 bilinear 的超参数. 具体的推导可以见 here.
-
我们现在仅关注 xt, 然后看看一个具体的例子. 取 A=−1,B=1,α=1,Δt=exp(z), 我们有
xt=11+exp(z)xt−1+exp(z)1+exp(z)ut=(1−σ(z))xt−1+σ(z)ut.
这实际上就是一个 gating 机制 (常常用在 RNN 的更新上).
-
ok, 对于 RNN, 我们可以把其中的一层看出是对 (1) 的一次近似, 那么多层的效果是什么? 作者认为这和 Picard iteration 有关系, Picard iteration, 即
xi+1(t):=xi(t0)+∫tt0f(s,xi(s))ds
可以证明随着 i 的增加, 会逐步收敛到真实解, 换言之, 多层的叠加可以让误差越来越小.
-
Deep LSSLs:
- LSSLs 的具体构造就是上述离散过程的叠加, 同时不同 block 之间添加 skip connection 和 layer norm.
- 假设我们的输入信号是 RL×H 的, 其中 L 表示序列长度, H 是维度, 此时信号不是 1 维的. 作者的做法是, 为每个维度单独设立:
A∈RN×N,B∈RN×1,C∈R1×N,D∈R,Δt∈R
分别进行上述的离散过程. 此外, 这里需要注意的是, Δt 我们也是可学习的.
- 作者还提到, 输出信号 y(t) 不一定必须是 1 维的, 也可以是 M 维的, 此时 C∈RM×N,D∈RM×1. 这会导致最后的输出维度是 H⋅M, 可以通过 MLP 映射回 H.
所以总共的参数量为:
HNN+HN1+HMN+HM+H+HMH=O(HN2+HMN+H2M).
注: 作者好像把 LSSL 的代码删掉了, 不过我注意到后续的 S4 的设定里面, 是为每个维度单独设立 A,B,C,D 还是共享是可以选择的 (也可以是部分维度共享, 取决于 n_ssm
这个参数).
注: A 的初始化应用 HiPPO.
和其它方法的联系

-
LSSL 除了可以看成是 RNN 外, 实际上还具有卷积的特性, 容易发现:
yk=C(¯A)k¯Bu0+C(¯A)k−1¯Bu1+⋯+C¯¯¯¯¯¯¯¯ABuk−1+¯Buk+Duk=∑sC(¯A)k−sus,
故
y=KL(¯A,¯B,C)∗u+Du.
其中
KL(A,B,C):=(CB,CAB,…,CAL−1B).
-
所以, LSSL 具有卷积的优点, 可以并行计算.
代码
[official-code]
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
2023-06-11 Graph Neural Networks Inspired by Classical Iterative Algorithms
2019-06-11 Proximal Algorithms 5 Parallel and Distributed Algorithms