Batch Normalization 以及 Pytorch的实现
Batch Normalization 以及 Pytorch的实现
Internal Covariate Shift
在说明 Batch Normalization 之前, 我们需要知道为什么要引入 Batch Normalization, Batch Normalization可以使得我们在训练阶段使用更大的学习率, 以及不必过分的关心模型参数的初始化. 为什么这么说呢, 因为在非线性神经元, 例如 \(sigmoid\) 函数, 在反向传播的过程中, 使用随机梯度下降的时候, 我们求导的结果不仅与训练数据有关, 而且与模型的参数有关, 例如:
我们知道 \(sigmoid\) 函数的特点是, 在 \(y\) 的绝对值很大的时候, 函数的梯度特别小, 这使得网络进入一个饱和状态, 并且收敛速度十分慢, 而 \(y\) 又受到 \(W\) 和 \(b\) 的影响, 所以模型的参数对模型的训练影响很大.
Internal Covariate Shift 的定义是, 训练过程中由于网络参数的变化而导致的网络激活层输出分布的变化. 通过在训练过程中, 格式化网络输入层的分布, 可以提高训练的速度. 这又是为什么呢, 这是由于反向传播更新参数, 导致每一层的输出分布不确定, 下次反向传播参数的更新方向也不确定, 为了防止网络震荡, 所以我们只能选择较小的学习率, 这里参考 大神博客 感觉讲的十分清楚, 本文着重于 Pytorch 是怎么实现 Batch Normalization 的.
Batch Normalization
Batch Normalization的前向传播
Batch Normalization 的思想是作用在输入层上的, 但是单层网络往往不会考虑, 在多层网络的时候, 我们也更多的考虑中间层, 当然输入层有时候也会考虑. 假设 DNN 某一层的输出有 \(k\) 个神经元, 这意味着该层的输出维度是 \(k\), 然后 Batch_Size 为 \(m\), 表示 \(m\) 个数据, 所谓的Batch Normalization, 就是要对这 \(m\) 个数据点的每一维度做归一下, 每个数据点(vector)的维度是 \(k\). 然后我们将归一化的结果作为下一层的输入.
具体的, \(BN\) 层的输入是 Values of \(x\) over a mini-batch: \(B = {x_1...x_m}\); 需要学习的参数是: \(\gamma , \beta\) , 注意, 这里的 \({x_1...x_m}\) 是 \({x_1^d...x_m^d}\) 的缩写, 表示的是第 \(d\) 个特征上的一个行向量.
输出是 \({y_i = BN_{\gamma},\beta(x_i)}\)
BN 层的计算过程就是:
其操作可以分成2步,
- Standardization:首先对\(m\)个 \(x\) 进行 Standardization,得到 zero mean unit variance的分布 \(\hat{x}\).
- scale and shift:然后再对 \(\hat{x}\) 进行scale and shift,缩放并平移到新的分布 \(y\),具有新的均值 \(\beta\)方差 \(\gamma\).
上面的例子是\(x\) 表示在特定的维度下, 所以\(x\) 可以看成 \(1 \times m\) 的矩阵, 现在假设BN层有\(k\)个输入节点,则\(x\)可构成\(k×m\)大小的矩阵\(X\),BN 层相当于通过行操作将其映射为另一个\(k×m\)大小的矩阵\(Y\),\(Y\) 矩阵的每一特征(维度)都进行一次归一化, 结果就是每一维的数据会归一化为均值为 \(\gamma_i\), 标准差为 \(\beta_i\). 用公式表示就是:
其中 \(i\) 表示的是维度, 那么 \([x_i^{(1)}, x_i^{(2)}, \dots, x_i^{(m)}]\) 可以看作是 mini-batch 组成的一个行向量(表示的是某一维度), 或者说上一层网络的某一个神经元的输出, 而 \([x_1^{(i)}, x_2^{(i)}, \dots, x_k^{(i)}]\) 表示的是上一层网络的输出的单个数据点(样本).
我们可以得出:
- \(\mu\) 和 \(\sigma\) 为当前行(在某一特征下的)的统计量,不可学习.
- \(\gamma_i\) 和 \(\beta_i\) 为待学习的scale和shift参数,用于控制 \(y_i\) 的方差和均值
- BN层中,\(x_i\) 和 \(x_j\) 之间不存在信息交流\((i≠j)\), 即不同的特征, 或者维度之间不存在关系.
卷积层使用 Batch Normalization
卷积层, 例如对于图像的卷积的时候, 我们往往不会考虑每一个像素, 注意, 实际上, 往往每一个像素是作为一个特征, 并且还有其 \(RGB\) 值, 这样特征就更多了, 但是, 我们考虑对于图像来说, 不同位置的像素特征本质上属性是相同的, 这里参考 Feature Map(特征图), 我们将一个特征图内的所有特征做相同的归一化, 也就是说, 上述的 \(\gamma\) 与 \(\beta\) 函数对于每个 Feature Map是相同的, 对不同的Feature Map 是不同的, 个人觉得 Feature Map的大小和CNN 的channels 大小相同, 或者说是 channels 的另一种说法. 我们可以从下面的例子看出他们之间的关系.
import torch
from torch import nn
# 随机生成一个Batch的模拟,100张16通道784像素点的数据
# 均匀分布U(0~1)
x = torch.rand(100, 16, 784)
# Batch Normalization层,因为输入是将高度H和宽度W合成了一个维度,所以这里用1d
layer = nn.BatchNorm1d(16) # 传入通道数, 本质就是 Feature Map的大小, 在 16个Feature Map上做归一化
# 随着 CNN 层数的前进, 这里变成上一层 CNN 中核的个数
out = layer(x)
x = torch.randn(1, 16, 7, 7) # 1张16通道的7乘7的图像
# Batch Normalization层,因为输入是有高度H和宽度W的,所以这里用2d
layer = nn.BatchNorm2d(16) # 传入通道数
out = layer(x)
Batch Normalization的反向传播
加入了 Batch Normalization, 需要更新反向传播的参数. 根据论文,
注意, 上述的反向传播是在一个维度(特征)上的反向传播, 不同于没有 Batch Normalization 的反向传播, 模型参数在梯度下降的时候, 不需要考虑不同维度问题, 这里需要对每一维度计算反向传播, 更新其 \(\gamma\) 与 \(\beta\) 值, 所以这里计算时间会增加, 但是 BN 可以增大学习率, 总的来说, 计算时间还是会降低. 但是 Batch Normalization 计算所占的比例会增大, 大约 1/4.
Batch Normalization的预测阶段
对于\(\mu\)和 \(\sigma\) ,在训练阶段,它们为当前mini batch的统计量,随着输入 batch 的不同,\(\mu\)和 \(\sigma\) 一直在变化。在预测阶段,输入数据可能只有1条,该使用哪个\(\mu\)和 \(\sigma\) ,或者说,每个BN层的\(\mu\)和 \(\sigma\) 该如何取值, 在 Pytorch 实现中, 我们可以根据参数自主选择, 在后面我们会讲到.
Pytorch 源码的实现
_NormBase 基类
Pytorch 的 Batch Normalization 主要是基于_NormBase
类与 class _BatchNorm(_NormBase):
这两个类来实现的, 其中 _BatchNorm
是继承自 _NormBase
, 其中很多重要的属性以及函数都是继承子 _NormBase
, _BatchNorm
是一层神经网络, 和 Linear
这种类似, 关键是 forward
函数, 我们先看下 _NormBase
类以及其信息.
_NormBase
初始化的时候定义了它的属性:
class _NormBase(Module):
"""Common base of _InstanceNorm and _BatchNorm"""
_version = 2
__constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
num_features: int
eps: float
momentum: float
affine: bool
track_running_stats: bool
# WARNING: weight and bias purposely not defined here.
# See https://github.com/pytorch/pytorch/issues/39670
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(_NormBase, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
self.running_mean: Optional[Tensor]
self.running_var: Optional[Tensor]
self.register_buffer('num_batches_tracked',
torch.tensor(0, dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
else:
self.register_buffer("running_mean", None)
self.register_buffer("running_var", None)
self.register_buffer("num_batches_tracked", None)
self.reset_parameters()
它的重要属性就是
num_features: int
eps: float
momentum: float
affine: bool
track_running_stats: bool
其中
num_features
很好理解, 就是 Batch Normalization 特征的数目, 例如 CNN 中就是 feature MAP的个数,momentum
: 计算 \(\mu\) 和 \(\sigma\) 的滑动平均系数. 采用下列的公式:
affine
: 表示是否使用, \(\gamma_i\) 和 \(\beta_i\) 参数, 在代码中就是weight
与bias
参数.track_running_stats
: 表示是否记录训练过程中的 \(\mu\) 和 \(\sigma\) , 在预测过程中是否使用, 不使用那么, 在预测过程中, 如代码所示:
def reset_running_stats(self) -> None:
# 将预测过程中的 running_mean 和 running_var 初始化为标注正态分布
if self.track_running_stats:
# running_mean/running_var/num_batches... are registered at runtime depending
# if self.track_running_stats is on
self.running_mean.zero_() # type: ignore[union-attr]
self.running_var.fill_(1) # type: ignore[union-attr]
self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
- 我觉得还有一点就是向模型中添加参数, 这里的
affine
就是使用的这种方式:
if self.affine:
self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
这里使用 Parameter() 将参数Parameters 添加到设备中去.
_BatchNorm
_BatchNorm 是 基于 _NormBase
的类, 在网络中, 可以看作是一层, 最重要的部分在 forward
函数中, 我们看一下:
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX. 滑动平均系数,
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None: # type: ignore[has-type]
# 记录 batch 的个数, 也就是 running mean 的个数
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
# 使用累计滑动平均, 或者指数滑动平均
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
# 这里主要是判断 evaluation 阶段是否使用 mini-batch 计算得到的running mean 与 running variance
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean
if not self.training or self.track_running_stats
else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)