Hands-On Deep Learning Implementation (11): Batch Normalization
Link: https://www.cnblogs.com/greentomlee/p/12314064.html
github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning
8. Batch Normalization
Batch normalization keeps the distribution of activations in each layer appropriate, so that "learning" proceeds smoothly. BatchNorm has the following advantages:
1. Learning converges faster.
2. The network no longer depends heavily on the initial weight values.
3. It suppresses overfitting (keeps the weight values in check).
1.1 Implementing batch normalization
Batch norm standardizes the data in units of the training mini-batch, so that within each mini-batch the data distribution has mean 0 and variance 1.
The idea behind batch norm is to adjust the distribution of the activations in each layer so that it has an appropriate spread; it is a layer that makes the neural network standardize the data distribution itself.
For example: given a mini-batch B of m input samples, compute its mean and variance, transform B into data that follows the standard normal distribution, and then apply a scale and shift to the standardized data.
The forward processing of the data therefore proceeds as follows:
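Following the BN paper cited in the code below (http://arxiv.org/abs/1502.03167), for a mini-batch B = {x_1, ..., x_m} the forward pass is:

\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i,\qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_B\right)^2

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}},\qquad
y_i = \gamma\,\hat{x}_i + \beta

Here ε is a small constant (10e-7 in the code below) that prevents division by zero, and γ (scale) and β (shift) are learned parameters.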
The backward pass is essentially the derivative of the forward pass and can be derived with a computational graph:
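For reference, the gradients given in the BN paper (the same quantities computed in __backward below) are:

\frac{\partial L}{\partial \hat{x}_i} = \gamma\,\frac{\partial L}{\partial y_i}

\frac{\partial L}{\partial \sigma_B^2} = -\frac{1}{2}\sum_{i=1}^{m}\frac{\partial L}{\partial \hat{x}_i}\,\frac{x_i-\mu_B}{\left(\sigma_B^2+\varepsilon\right)^{3/2}}

\frac{\partial L}{\partial \mu_B} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{x}_i}\,\frac{-1}{\sqrt{\sigma_B^2+\varepsilon}} + \frac{\partial L}{\partial \sigma_B^2}\,\frac{\sum_{i=1}^{m}-2\left(x_i-\mu_B\right)}{m}

\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\,\frac{1}{\sqrt{\sigma_B^2+\varepsilon}} + \frac{\partial L}{\partial \sigma_B^2}\,\frac{2\left(x_i-\mu_B\right)}{m} + \frac{\partial L}{\partial \mu_B}\,\frac{1}{m}

\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}\,\hat{x}_i,\qquad
\frac{\partial L}{\partial \beta} = \sum_{i=1}^{m}\frac{\partial L}{\partial y_i}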
import numpy as np


class BatchNormalization:
    """
    Batch Normalization layer (http://arxiv.org/abs/1502.03167)
    """
    def __init__(self, gamma, beta, momentum=0.9, running_mean=None, running_var=None):
        self.gamma = gamma
        self.beta = beta
        self.momentum = momentum
        self.input_shape = None  # 4D for convolutional layers, 2D for fully connected layers

        # running mean and variance, used at inference time
        self.running_mean = running_mean
        self.running_var = running_var

        # intermediate results kept for backward
        self.batch_size = None
        self.xc = None
        self.xn = None
        self.std = None
        self.dgamma = None
        self.dbeta = None

    def forward(self, x, train_flg=True):
        self.input_shape = x.shape
        if x.ndim != 2:
            N, C, H, W = x.shape
            x = x.reshape(N, -1)

        out = self.__forward(x, train_flg)

        return out.reshape(*self.input_shape)

    def __forward(self, x, train_flg):
        if self.running_mean is None:
            N, D = x.shape
            self.running_mean = np.zeros(D)
            self.running_var = np.zeros(D)

        if train_flg:  # training: normalize with the statistics of the current mini-batch
            mu = x.mean(axis=0)
            xc = x - mu
            var = np.mean(xc ** 2, axis=0)
            std = np.sqrt(var + 10e-7)
            xn = xc / std

            self.batch_size = x.shape[0]
            self.xc = xc
            self.xn = xn
            self.std = std
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mu
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        else:  # inference: normalize with the running statistics
            xc = x - self.running_mean
            xn = xc / np.sqrt(self.running_var + 10e-7)

        out = self.gamma * xn + self.beta
        return out

    def backward(self, dout):
        if dout.ndim != 2:
            N, C, H, W = dout.shape
            dout = dout.reshape(N, -1)

        dx = self.__backward(dout)

        dx = dx.reshape(*self.input_shape)
        return dx

    def __backward(self, dout):
        dbeta = dout.sum(axis=0)
        dgamma = np.sum(self.xn * dout, axis=0)
        dxn = self.gamma * dout
        dxc = dxn / self.std
        dstd = -np.sum((dxn * self.xc) / (self.std * self.std), axis=0)
        dvar = 0.5 * dstd / self.std
        dxc += (2.0 / self.batch_size) * self.xc * dvar
        dmu = np.sum(dxc, axis=0)
        dx = dxc - dmu / self.batch_size

        self.dgamma = dgamma
        self.dbeta = dbeta

        return dx
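A minimal smoke test of the layer above (illustrative only, not from the original post): it feeds a random mini-batch through forward and backward, and with gamma = 1 and beta = 0 the training-time output of every feature should have mean ≈ 0 and standard deviation ≈ 1.

import numpy as np

np.random.seed(0)
D = 5                                            # number of features
x = np.random.randn(32, D) * 3.0 + 2.0           # mini-batch with non-zero mean and non-unit variance

bn = BatchNormalization(gamma=np.ones(D), beta=np.zeros(D))

out = bn.forward(x, train_flg=True)
print(out.mean(axis=0))                          # ~0 for every feature
print(out.std(axis=0))                           # ~1 for every feature

dout = np.random.randn(*out.shape)
dx = bn.backward(dout)
print(dx.shape, bn.dgamma.shape, bn.dbeta.shape)  # (32, 5) (5,) (5,)

# at inference time the running statistics are used instead of the batch statistics
out_test = bn.forward(x, train_flg=False)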
1.2 Testing batch normalization (on the MNIST dataset)
The test program picks 1000 samples as the training set and trains for 16 epochs; it compares how the weights evolve in a network with BN layers against one without, as shown in the figure below:
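The comparison script itself is not listed in this post; the sketch below shows one way such an experiment can be set up, tracking training accuracy per epoch as a simple proxy for the comparison. load_mnist, MultiLayerNetExtend (with a use_batchnorm switch), and SGD are assumed to come from the companion repository; the import paths and signatures here are hypothetical and may need to be adjusted to the actual repo layout.

import numpy as np
from dataset.mnist import load_mnist                            # hypothetical import path
from common.multi_layer_net_extend import MultiLayerNetExtend   # hypothetical import path
from common.optimizer import SGD                                # hypothetical import path

(x_train, t_train), _ = load_mnist(normalize=True, one_hot_label=True)
x_train, t_train = x_train[:1000], t_train[:1000]   # 1000 training samples, as in the post

def train(use_batchnorm, epochs=16, batch_size=100, lr=0.01):
    net = MultiLayerNetExtend(input_size=784, hidden_size_list=[100] * 5,
                              output_size=10, weight_init_std=0.01,
                              use_batchnorm=use_batchnorm)
    optimizer = SGD(lr=lr)
    acc_list = []
    iters_per_epoch = max(x_train.shape[0] // batch_size, 1)
    for epoch in range(epochs):
        for _ in range(iters_per_epoch):
            idx = np.random.choice(x_train.shape[0], batch_size)
            grads = net.gradient(x_train[idx], t_train[idx])
            optimizer.update(net.params, grads)
        acc_list.append(net.accuracy(x_train, t_train))
    return acc_list

acc_bn = train(use_batchnorm=True)
acc_plain = train(use_batchnorm=False)
print("with BN:   ", acc_bn)
print("without BN:", acc_plain)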
Friday, February 7, 2020
My heart is not a stone, it cannot be turned; my heart is not a mat, it cannot be rolled up.