[论文阅读] Replacing softmax with ReLU in Vision Transformers
Pre
title: Replacing softmax with ReLU in Vision Transformers
accepted: Arxiv 2023
paper: https://export.arxiv.org/abs/2309.08586
code: None
关键词:attention, parallelization
阅读理由:Google Deepmind,标题挺有意思
Idea
序列缩放能缓解ReLU等激活函数在attention中替换Softmax提高并行性时导致的性能下降,但不确定为何
Motivation&Solution
- 注意力中的softmax很重要,但妨碍了并行计算
Background
先前的研究发现如果把注意力的softmax换成逐点激活函数ReLU,准确度会有所下降,作者认为是他们没有将其除以序列长度导致的,同时那些方法仍然依靠normalization来使注意力权重总和为1,(仍保留了无法并行的缺点)
Method(Model)
Overview
用ReLU替换,同时将值除以序列长度,能缓解性能的下降
Attention 原本的注意力权重\(\alpha_{i,j}\)计算如下:
其中 L 是序列长度, \(\phi\) 是经典的softmax,而本文就是要探索它的 point-wise 替代。
ReLU-attention 观察到 \(\phi = L^{-1}relu\) 是公式1中softmax有希望的替代,将其称为ReLU-attention
Scaled point-wise attention 本文探索更一般的形式: \(\phi = L^{-\alpha}h,\; \alpha \in [0,1],\; h \in \{relu, relu^2, gelu, softplus, identity, relu6, sigmoid\}\)
Sequence length scaling 除以序列长度这事主要是实验的结果,但作者也给了一定的理论分析:现在的Transformer用的 sotfmax 注意力都要求 \(\sum^L_{j=1}\),这实际上就隐含了 \(\mathbb{E}_j[\alpha_{ij}] = L^{-1}\)。虽然可能不必要,但 \(\phi = L^{-1}relu\) 使得 \(\mathbb{E}_j[\alpha_{ij}]\) 在初始时在 \(O(L^{-1})\) 的量级。维持该条件可能会减轻替换掉softmax后对其他超参数的调整需求。
初始时 q, k 都是 \(O(1)\),因此 \(\frac{\left \langle q_i,k_j \right \rangle }{\sqrt{d}}\) 也是 \(O(1)\)。ReLU指着激活函数能维持\(O(1)\),因此 \(L^{-1}\)对于 \(\mathbb{E}_j[\alpha_{ij}]\) 维持 \(O(L^{-1})\) 是必要的
Experiment
Settings
用了BigVision codebase的训练配置,没有修改超参数。ImageNet-21k训练30epoch,ImageNet-1k训练300epoch,二者都差不多训练了9e5个step
Dataset
ImageNet-21k, ImageNet-1k
Results
图1 将sotfmax替换为relu/seqlen或是用qk-layernorm匹敌视觉Transformer的传统注意力缩放性能。该图展示了模型从小到大训30epoch的结果。
图2 将softmax替换为 Scaled point-wise attention 那种更为一般的形式,观察到\alpha在接近1的时候结果最好,但在此情况下选不出最好的激活函数,为了速度用了ReLU
图3 注意力去掉qk-layernorm,并且使用 $L^{-\alpha}$ 缩放的影响。
图4 使用门控注意力单元加上 $L^{-\alpha}$ 缩放的影响
Main experiment. 图1展示了ReLU-attention在ImageNet-21k上训练时与softmax attention一样的缩放趋势。但看起来性能似乎是稳定地不如。
Effect of sequence length scaling. 见图2
Effect of qk-layernorm. 实验用的qk-layernorm是在计算注意力权重之前把 q,k 拿去过layernorm,据说在提升模型尺寸时能防止不稳定。图3实验了去掉它的效果,表明似乎影响不大。
Effect of adding a gate. 研究加一个门控能否替代用序列长度缩放,但图4表明仍然要缩放才有最好的效果,门不门控没差别(横轴在0的时候并非最优,而且加上门控的线也不是平的)
Conclusion
仍留下许多开放问题,例如仍不确定为何因子 \(L^{-1}\) 会有用,以及该项是否可学。而且也可能有比ReLU更好的激活函数。
Critique
好短。看着google还以为有好东西,结论就是序列缩放有用但不知为何,而且说是提高并行性,不得比较一下吞吐和训练时间?图1有些太笼统了。