Hands-On Deep Learning Implementation (11): Batch Normalization


 

Link: https://www.cnblogs.com/greentomlee/p/12314064.html

 

github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning

 


 

8. Batch Normalization


Batch normalization keeps the distribution of each layer's activations in an appropriate range so that "learning" proceeds smoothly. BatchNorm has the following advantages:

1. It makes learning converge quickly.

2. It makes training less dependent on the initial weight values.

3. It suppresses overfitting (it has a regularizing effect on the weights).

 

1.1 Implementing batch normalization

Batch norm works in units of the training mini-batch: it standardizes the data within each mini-batch so that the distribution has mean $\mu = 0$ and variance $\sigma^2 = 1$.

The idea behind batch norm is to adjust the distribution of each layer's activations so that it has an appropriate spread; it is a layer in which the network standardizes the data distribution.

For example, for a mini-batch of $m$ input samples $B = \{x_1, x_2, \dots, x_m\}$, we compute the mean $\mu_B$ and the variance $\sigma_B^2$, and then transform $B$ into standardized data $\{\hat{x}_1, \hat{x}_2, \dots, \hat{x}_m\}$ that follow a standard normal distribution. The standardized data are then scaled and shifted.

The forward pass over the data is therefore:

 

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_B\right)^2$$

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$
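As a quick sanity check, the forward equations can be evaluated directly with NumPy on a small random batch (a minimal sketch; the batch shape, the value of $\varepsilon$, and the $\gamma$, $\beta$ values are chosen only for illustration):

import numpy as np

np.random.seed(0)
x = np.random.randn(8, 4) * 3.0 + 5.0      # mini-batch of m = 8 samples, 4 features

mu = x.mean(axis=0)                        # per-feature mean over the mini-batch
var = x.var(axis=0)                        # per-feature variance over the mini-batch
xn = (x - mu) / np.sqrt(var + 1e-7)        # standardized data: mean ~0, variance ~1

gamma, beta = 2.0, 0.5                     # scale and shift
y = gamma * xn + beta

print(xn.mean(axis=0))                     # close to 0 for every feature
print(xn.std(axis=0))                      # close to 1 for every feature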

The backward pass is essentially the derivative of the forward pass, and it can be derived with a computational graph:

 

[Figure: computational graph used to derive the backward pass of batch normalization]
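Since the computational-graph figure is not reproduced here, the result of the derivation is written out explicitly below. These are the standard batch-norm gradients from the paper cited in the code (http://arxiv.org/abs/1502.03167); the `__backward` method below implements an algebraically equivalent rearrangement of them. For a mini-batch of size $m$:

$$\frac{\partial L}{\partial \hat{x}_i} = \gamma\,\frac{\partial L}{\partial y_i}$$

$$\frac{\partial L}{\partial \sigma_B^2} = -\frac{1}{2}\sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,(x_i-\mu_B)\,(\sigma_B^2+\varepsilon)^{-3/2}$$

$$\frac{\partial L}{\partial \mu_B} = -\sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,\frac{1}{\sqrt{\sigma_B^2+\varepsilon}}$$

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\,\frac{1}{\sqrt{\sigma_B^2+\varepsilon}} + \frac{\partial L}{\partial \sigma_B^2}\,\frac{2(x_i-\mu_B)}{m} + \frac{1}{m}\,\frac{\partial L}{\partial \mu_B}$$

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}\,\hat{x}_i, \qquad \frac{\partial L}{\partial \beta} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}$$

(The paper's $\partial L/\partial \mu_B$ has an extra term proportional to $\sum_i (x_i-\mu_B)$, which vanishes because the deviations sum to zero.)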

 

import numpy as np


class BatchNormalization:
    """
    Batch Normalization layer (http://arxiv.org/abs/1502.03167)
    """
    def __init__(self, gamma, beta, momentum=0.9, running_mean=None, running_var=None):
        self.gamma = gamma
        self.beta = beta
        self.momentum = momentum
        self.input_shape = None  # 4D for convolution layers, 2D for fully-connected layers

        # running mean and variance (used at inference time)
        self.running_mean = running_mean
        self.running_var = running_var

        # intermediate results cached for backward
        self.batch_size = None
        self.xc = None
        self.xn = None
        self.std = None
        self.dgamma = None
        self.dbeta = None

    def forward(self, x, train_flg=True):
        self.input_shape = x.shape
        if x.ndim != 2:
            # flatten (N, C, H, W) input to (N, C*H*W)
            N, C, H, W = x.shape
            x = x.reshape(N, -1)

        out = self.__forward(x, train_flg)

        return out.reshape(*self.input_shape)

    def __forward(self, x, train_flg):
        if self.running_mean is None:
            N, D = x.shape
            self.running_mean = np.zeros(D)
            self.running_var = np.zeros(D)

        if train_flg:  # training: standardize with the statistics of the current mini-batch
            mu = x.mean(axis=0)
            xc = x - mu
            var = np.mean(xc ** 2, axis=0)
            std = np.sqrt(var + 10e-7)
            xn = xc / std

            self.batch_size = x.shape[0]
            self.xc = xc
            self.xn = xn
            self.std = std
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        else:  # inference: standardize with the running statistics
            xc = x - self.running_mean
            xn = xc / np.sqrt(self.running_var + 10e-7)

        out = self.gamma * xn + self.beta
        return out

    def backward(self, dout):
        if dout.ndim != 2:
            N, C, H, W = dout.shape
            dout = dout.reshape(N, -1)

        dx = self.__backward(dout)

        dx = dx.reshape(*self.input_shape)
        return dx

    def __backward(self, dout):
        dbeta = dout.sum(axis=0)
        dgamma = np.sum(self.xn * dout, axis=0)
        dxn = self.gamma * dout
        dxc = dxn / self.std
        dstd = -np.sum((dxn * self.xc) / (self.std * self.std), axis=0)
        dvar = 0.5 * dstd / self.std
        dxc += (2.0 / self.batch_size) * self.xc * dvar
        dmu = np.sum(dxc, axis=0)
        dx = dxc - dmu / self.batch_size

        self.dgamma = dgamma
        self.dbeta = dbeta

        return dx
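A minimal usage sketch (assuming the class above is defined and `numpy` is imported as `np`; the shapes, parameter values, and the upstream gradient `dout` are made up for illustration):

D = 4                                        # number of features
bn = BatchNormalization(gamma=np.ones(D), beta=np.zeros(D))

x = np.random.randn(16, D) * 2.0 + 3.0       # a mini-batch of 16 samples
out = bn.forward(x, train_flg=True)          # training mode: batch statistics
print(out.mean(axis=0), out.std(axis=0))     # roughly 0 and 1 per feature

dout = np.random.randn(16, D)                # pretend gradient from the next layer
dx = bn.backward(dout)                       # gradient with respect to the input
print(dx.shape, bn.dgamma.shape, bn.dbeta.shape)

out_test = bn.forward(x, train_flg=False)    # inference mode: running statistics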

 

1.2 Testing batch normalization (on the MNIST dataset)

The program picks 1,000 samples as the training set and trains for 16 epochs, comparing how the weights change in a network with a BN layer and in a network without one, as shown in the figure below:

[Figure: comparison of weight changes for the network with and without the BN layer]

The figure below shows the same comparison as above:

[Figure: comparison with and without the BN layer, repeated]
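The MNIST experiment itself is not reproduced here, but the effect it measures can be illustrated with a small self-contained sketch that uses only the `BatchNormalization` class above: a random batch is pushed through a stack of randomly initialized layers, and the fraction of saturated `tanh` activations is measured with and without BN (all names, shapes, and constants below are illustrative and not part of the original experiment):

import numpy as np

def saturated_fraction(use_bn, n_layers=5, n_units=50, seed=1):
    """Fraction of tanh activations with |a| > 0.99 at each layer depth."""
    rng = np.random.RandomState(seed)
    a = rng.randn(100, n_units)                      # a random input batch
    rates = []
    for _ in range(n_layers):
        a = a @ rng.randn(n_units, n_units)          # deliberately un-tuned weight scale
        if use_bn:
            bn = BatchNormalization(gamma=np.ones(n_units), beta=np.zeros(n_units))
            a = bn.forward(a, train_flg=True)
        a = np.tanh(a)
        rates.append(round(float(np.mean(np.abs(a) > 0.99)), 3))
    return rates

print("saturated fraction per layer, without BN:", saturated_fraction(use_bn=False))
print("saturated fraction per layer, with BN   :", saturated_fraction(use_bn=True))

Without BN, most activations in the deeper layers are pushed into the saturated region of tanh, while with BN every layer keeps an appropriate spread, which is the property that helps learning converge in the MNIST experiment above.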

Friday, February 7, 2020
