LLM大模型: mamba的诞生和改进点
transformer的核心attention机制效果出奇地好,诞生了chatGPT这种里程碑式milestone的产品,但是attention机制本身的缺点也很明显:time & space complexity 高达 O(N^2); transformer架构2017年发的论文,至今已7年了,时间和空间复杂度的问题有没有解决了?
要想替代attention,核心是“取其精华,去其糟粕”,精华有:
- attention机制下,第t个token可以通过点积得到weight的方式,找到前面t-1个token中哪些token的信息要多保存一些,哪些token的信息要少保存甚至丢掉;也就是精准地选择前面t-1个token中哪些信息重要,需要保存,哪些信息不重要,可以减弱甚至遗忘丢弃!
- 可以并行计算
糟粕有:
- time & space complexity 高达 O(N^2)
现在问题来了:什么样的网络结构能保留优点、去掉缺点了?
1、先来看看并行计算的问题:token之间的点积是互相独立的,这是attention能并行的根本原因;点积本质是计算两个分布之间的距离,也就是这两个分布是否接近,这和卷积的作用不能说非常相似,只能说一摸一样!所以卷积是不是能替换点积,同时又能达到并行计算的效果了?
再来看上文信息保存:t 时刻token的生成,肯定要依赖t-1时刻之前所有token的信息,有些token的信息需要保存甚至加强,有些token的信息需要减弱甚至去掉,这个怎么实现了?attention是在value向量中保存上文相关token信息的,如果不用attention机制,这些上文相关token的信息怎么保存和更新了? 传统的MLP,直接根据input的数据计算后得到output数据,这么干简单粗暴,但是没法记录t-1之前所有token的信息,所以在MLP中间肯定是要增加一个vector或matrix的,用来记录截至t-1时刻的所有token信息,比如下面这样的:
输入X seq,输出 Y seq;每次输入一个token,相关信息先存在hidden层,hidden层的信息也会根据当前输入更新,然后用更新后的数据再进一步生成output,这个idea充分利用了t-1时刻之前的token信息,看起来是不是很完美了?这就是RNN啊!上图过于简陋,很多细节还要完善,比如input的x怎么得到hidden的值?hidden怎么根据最新的数据更新历史存量数据?output y怎么根据hidden生成了?这些细节不想好,光看上面的流程图还是没法work啊!具体的细节展示如下:
从图上可以看出,每一步都是通过矩阵乘法向下传递信息的!途中阴影部分是整个流程的核心,名叫 state space model;硬要翻译成中文,就是状态空间模型。最核心的就是state representation或者说A矩阵啦,相当于attention中的value向量,存储了t-1之前所有token的信息,所以取名state space!整个流程用公式表示如下:
这就完了? 上述的网络结构一点问题都没有么?
- 上述整个流程都是串行的,说好的并行在哪了?
- 中间A是个矩阵,能存储多少信息了?如果t很大,input sequence very long,state representation会不会覆盖掉早期的token信息了?
2、上述公式都是连续形式的,但实际token之间是离散的,所以要想办法把上述的公式先转成离散的。首先要做的就是把连续的输入信号按照一定的间隔切分成离散信号,这个间隔就是步长Δ;然后对A、B矩阵做零阶保持,结果如下:
这里使用k表示token的顺序,不再使用t表示,所以上述的计算公式改变如下:
关于并行比较精妙的地方来了:以y2 token生成为例,套用上面的公式,把递归嵌套层层展开后得到如下公式:
这看着像不像卷积了?把上述公式继续抽象概括如下:
这个公式看起来还是有点抽象,举个例子就清晰了:
根据输入的x0、x1、x2来预测输出的y2,不就是个点积么?这不就是卷积么?再归纳一下如下:
简写成如下形式:
A、B、C都是离散的参数,所以卷积核是可以实现计算好保存起来的,而从y1 ~ yk都是可以根据上述公式同时计算出来的!所以说把信号从连续离散化,是可以使用卷积核并行训练的!每个输入的token在生成y时都有各自特定的系数,这些系数完全由A、B、C三个矩阵决定!
3、从上述yk的公式可以看出,yk的生成依赖x0~xk个token的信息,这些信息都要存放在矩阵A里面,那么问题来了:sequence越长,需要存储的token信息就越多,但A矩阵又不可能无限扩大,否则time & space complexity直逼 attention的O(n^2)了,还有啥改进的意义了?怎么解决在有限的存储空间中有效地解决序列建模的长距离依赖问题了?此刻又有大牛站出来,发明了HiPPO:High-order Polynomial Projection Operator!HiPPO 矩阵通常用于将连续时间信号投影到正交多项式基下,以代表过去的状态/信息,用人话解释就是:能在存储空间有限的前提下,有效地存储k-1之前所有token的信息!既能存储k-1号token附近token的信息,又能适当衰减远距离token信息!hippo矩阵可通过函数逼近产生状态矩阵 A 的最优解,公式如下:
根据上述公式,3*3的A矩阵如下:
从上面的矩阵可以很直观地感受到上面三角围为0,相当于“滤波器”,这个矩阵很好的解释了两点:
- 第k个token只能由前面k-1个token决定,未来的token不能影响当前token,所以上三角全是0,也就是weight是0
- 下三角的每个值,直接决定了k-1之前token对第k个token的影响weight,比如A20 = -根号5,意味着第0个token对第2个token的影响weight是-根号5
直观上讲,A矩阵有点像transformer的position矩阵!如果N和K很大,A矩阵岂不是也很大?所以可以进一步做Non-Parametric Low-Rank分解,原论文说明如下:
这里把A分解成两个low rank的P和Q矩阵,可以减少存储空间,减少计算量,这个和lora微调的思路完全一样!所以这一步的核心还是P和Q矩阵;
4、transformer中,为了更好的提取语义特征,attention一般有multi-head,比如8个head,所以原本每个token有512维度,因为要分别进入8个head做处理,所以每个token的512维度均分成8分,每份64维。8个head分别在不同的语义空间做转换,比如apple这个token,究竟是水果了,还是电子消费品?不同的head代表了不同的语义空间,可以有效地根据其context区分apple的语义空间,所以理论上讲,head越多越好(语义空间越精细越好),技术上用512个head都行!借鉴这个思路:state space model的每个dimension都有独立的语义空间,各个矩阵应该怎么改造了??
原始的样本输入:batch size B x sequence length L x embedding dim D,N指SSM的隐藏层维度hidden dimension,也就是state representation的维度;
这里极端一点:如果要让每个D都有一个类似transformer的head,state space model的矩阵形状该怎么改造了?如下:都改成D*N形状的矩阵呗!
每个Dimension都有L个数值用来做representation,这里举个更加形象通俗的例子:比如一张图片有RGB三个通道,也就是3个dimension,如果生成3*64的矩阵,那么每个dimension都有64个数字来做描述和表征,极大地丰富了dimension的信息!
5、这里在上述的思路上再大胆一点:既然每个Dimension都有自己的representation(能让每个维度精细化表征),能不能再扩展一下,让每个token有精细化的表征了?换句话说,每个sequence中不同token的权重应该是不同的(这是mamba和RNN的核心区别之一),比如“我昨天去了隔壁老王家,和老王亲切友好地会谈了很多家务事宜”,这句话有20+token,里面有“了、地”这种语气助词,对语义理解没任何作用,这种词的weight应该接近0;而这句话的主要意思是谈家务,主体是我和老王,所以“谈、家务、我、老王”这几个token的weight应该是最大的!attention机制通过点积找到了token之间的关系,也就是最重要的token,mamba架构应该怎么区分不同token的weight了?
- 本质上讲:A相当于存储器,存储k-1个token的信息;B相当于过滤器,保留当前token有用的信息,去掉没用的,作用和encoder相似;C相当于decoder,把存储器A中的信息输出展示!
token进来之后,分别和A、B、C三个矩阵交互做计算,所以要想每个token都个性化计算weight,就要对这三个矩阵改造!方案如下(右边):
- A是存储过往token信息的矩阵,核心存储了每个dimension的N维的表征信息,不用改形状;
- 步长Δ:通过B*L定位到特定的token后,得到该token的所有dimention信息,和输入token的dimension保持一致
- B和C: 通过B*L定位到特定的token后,得到该token的所有hidden state信息
- 每个token的dimension没了,去哪了?这里简单粗暴,直接用N替换了D,反正都是存储语义信息的,叫什么名字不重要,重要的是存储的内容;因为N > D,所以存储信息的能力增加了!
做了这种改变后,使得每个token都有一个独立的B、C、Δ,从而可以使每个token 都有独立的卷积核,更精细地提取token在hidden representation的特征!也就是说:
- 每个token都有自己独立的D*N做representation
- 每个token的每个dimension都有自己的N做feather representation
这么做的本质还是新增维度数据来更精细地表征token的特征:从D个维度增加到D*L个维度,好比image的pixel翻了很多倍,清晰度也高了很多倍!新增了这么多的维度,这些维度的数据怎么得到了?那就多增加卷积核呗,给每个token都使用单独的卷积核提取D*L个特征!这就是所谓的selective state space model,图示如下:
整个网络架构流程的核心就两点:
- 为每个token选择合适的步长Δ,灵活控制hidden state中信息的更新频率,作用类似遗忘门;较小的步长Δ会忽略当前输入,而更多地使用先前的上文,而较大的步长Δ会更多地关注当前输入而不是上文
- 为每个token选择合适的B、C矩阵(传统的image使用不同的kernel提取不同的特征,比如纹理、斑点等,这里针对不同token也使用不同的kernel提取不同的语义特征)
上述两个目的就是达到有选择性地记忆和遗忘t-1时刻之前的token信息(B和Δ过滤输入token信息,选择性提取有用特征,存储在A中),动态控制sequence中每个token的关注度,让模型自适应处理不同的输入特征!
6、特征提取的问题解决了,又要回到计算效率的问题上啦!上面不是说训练的时候把整个层层递归的流程展开后用cnn并行训练么,推理的时候怎么办了?用同样的方法思路层层展开,举例如下:
图示如下:
看吧,hidden state的计算是不是就能并行了? 先把H0求出来,接下来每个H都能并行了啊!这就是所谓的硬件感知,并行扫描(parallel scan)!
7、最核心的SSSM介绍完毕,mamba的架构如下:
- projection:先升维,让embedding的维度增加,能承载更多的信息,比如更加精细、复杂的特征
- convolution:提取局部特征,和下一步提取中长期依赖关系的SSM形成互补
- selected SSM:选择性地保存中长期依赖关系
- 主干道旁边有个旁路:用于做embedding中的维度选择
总结:
1、mamba和transformer的input和output都是一样的,所以理论上讲transformer能干的事,mamba都能干,都可以平替!
2、这里打个岔,说个题外话:transformer最大的缺陷可能就是time & space complexy O(N^2)了,除了mamba,还有研究员给了一些改进的方案,其中比较出名的就是flash attention了,核心思路是分块:数据量不是大么?那我就分解,把大的数据切块,分成多个小份,每个小份分别放到gpu中不同的core计算,这不就实现了并行计算了么?切分后,复杂度对比如下:
- 因为总的计算量没变,所以时间复杂度没变;但因为分成了很多小份并行计算,所以速度能加快,而且对显存的需求小很多
- 假设B是分块大小,B << N, 那么空间复杂度可降低至 O(N * D)
参考:
1、https://blog.csdn.net/v_JULY_v/article/details/134923301