On the Parameterization and Initialization of Diagonal State Space Models

Gu A., Gupta A., Goel K. and Re C. On the parameterization and initialization of diagonal state space models. NeurIPS, 2022.

Mamba 系列第四作: S4D.

符号说明

  • u(t)R, 输入信号;
  • x(t)RN, 中间状态;
  • y(t)R, 输出信号

S4D

  • LSSL 中我们已经阐述了线性系统:

    x(t)=Ax(t)+Bu(t),y(t)=Cx(t)+Du(t)

    在兼顾 RNN, CNN 的优势的可能性, 并且离散化后说明 LSSL 实际上可以改写成卷积的形式, 从而实现高效的并行化:

    y=KL(A¯,B¯,C)u+Du,KL(A,B,C):=(CB,CAB,,CAL1B).

  • S4 的初衷是对角化 A 来避免卷积过程中 Al 的复杂运算, 不过考虑到完全对角化的一个数值问题, 最终 S4 给出的策略是重参数化 A 为对角矩阵 + 低秩矩阵.

  • 不过最近 DSS 发现通过合理的初始化, 就能够避免数值问题, 本文在此基础上进一步探索.

  • 首先作者考虑简化的 ODE:

    x(t)=Ax(t)+Bu(t),y(t)=Cx(t),

    并给出一个等价的卷积形式:

    K(t)=CetAB,y(t)=(Ku)(t).


proof:

  • 首先注意到:

    x(t)=Ax(t)+Bu(t)x(t)Ax(t)=Bu(t)etAx(t)etAAx(t)=etABu(t)(etAx(t))=etABu(t)etAx(t)=etAx(0)+0teτABu(τ)dτx(t)=x(0)+0te(tτ)ABu(τ)dτy(t)=0tCe(tτ)ABu(τ)dτ=(Ku)(t)x(0)=0


  • 接下来我们假设 ACN×N 为一个对角矩阵, 考虑到 BCN×1,CC1×N, 我们可以令 An,Bn,Cn 对应的第 n 个元素. 由此一来, 我们就会有

    K(t)=n=0N1CnKn(t),Kn(t):=enTetAB,

    其中 en{0,1}N 表示第 n 个元素为 1 其余为 0 的向量.

  • 离散化后, 我们有:

    y=uK¯,K¯=(CB¯,CAB¯,,CA¯L1B¯)CL.

  • 容易证明:

    K¯=[B¯0C0,,B¯N1CN1][1A¯0A¯02A¯0L11A¯1A¯12A¯1L11A¯N1A¯N12A¯N1L1],

    这是 Vandermonde matrix-vector multiplication.

  • 正常算, K¯ 需要 O(NL) 的计算量, 不过 Vandermonde matrix-vector multiplication 实际上有更快的算法, 可以达到 O(N+L) 的复杂度.

  • 最后, 作者讨论了初始化, A 可以用 HiPPO 矩阵的 DPLR 后的对角矩阵初始化, 或者用直接用对角线, 以及额外还有两种:

    S4D-Inv:An=12+iNπ(N2n+11),S4D-Lin:An=12+iπn.

  • 作者给了算法:

代码

[official-code]

posted @   馒头and花卷  阅读(74)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
点击右上角即可分享
微信分享提示