Root Mean Square Layer Normalization
Zhang B. and Sennrich R. Root mean square layer normalization. NIPS, 2019.
概
RMSNorm 节省时间.
RMSNorm
-
假设输入为 \(\mathbf{x} \in \mathbb{R}^m\), 然后
\[\mathbf{a} = \mathbf{W} \mathbf{x} \in \mathbb{R}^{n}, \\ \mathbf{y} = f(\text{Norm}(\mathbf{a}) + \mathbf{b}) \in \mathbb{R}^{n}. \]其中 \(f(\cdot)\) 是 element-wise 的激活函数.
-
LayerNorm 采取的是如下的方式 (注意, 下面的 \(/\) 是 element-wise 的):
\[\text{LayerNorm}(\mathbf{a}) = \frac{\mathbf{a} - \bm{\mu}}{\bm{\sigma}} \odot \mathbf{g}, \]其中
\[\bm{\mu} = \text{mean}(\mathbf{a}), \\ \bm{\sigma} = \sqrt{\text{mean}((\mathbf{a} - \bm{\mu})^2)}. \] -
RMSNorm 采用的是如下的方式:
\[\text{RMSNorm}(\mathbf{a}) = \frac{\mathbf{a}}{\text{RMS}(\mathbf{a})} \odot \mathbf{g}, \]其中
\[\text{RMS}(\mathbf{a}) = \sqrt{\text{mean}(\mathbf{a}^2)}. \] -
由于不用计算均值, RMSNorm 所需的计算时间会少一点, 但是效果是差不多的:
-
此外, RMSNorm 保留了一些重要的不变性:
代码
[official]