Batch Normalization
forward
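The forward pass is not written out in these notes, so here is a minimal NumPy sketch of the standard training-mode computation. The function name `batchnorm_forward`, the default `eps=1e-5`, and the cache layout `(gamma, xhat, var, eps)` are assumptions made for illustration and are not taken from the original post.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm forward pass for x of shape (N, D)."""
    mu = x.mean(axis=0)                    # per-feature batch mean
    var = x.var(axis=0)                    # per-feature (biased) batch variance
    xhat = (x - mu) / np.sqrt(var + eps)   # normalized input
    out = gamma * xhat + beta              # learned scale and shift
    cache = (gamma, xhat, var, eps)        # assumed layout, reused by a backward pass
    return out, cache

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4))
out, cache = batchnorm_forward(x, np.ones(4), np.zeros(4))
print(out.mean(axis=0))  # per-feature mean, close to 0
print(out.std(axis=0))   # per-feature std, close to 1
```

With `gamma = 1` and `beta = 0` the output is just the normalized input, which makes the zero-mean, unit-variance property easy to inspect.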
backward
Mean and variance
The partial derivative of the result with respect to x consists of three parts
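The equations for this step did not survive in these notes. A sketch of the usual decomposition is given here, assuming the notation $\mu_B$, $\sigma_B^2$ for the batch mean and variance, $\hat{x}_i$ for the normalized input, and $L$ for the loss; these symbols are chosen to match the variable names in the code and may differ from the original figures. Since $x_i$ reaches the loss directly through $\hat{x}_i$ and indirectly through both batch statistics, the chain rule yields three terms.

```latex
\mu_B = \frac{1}{N}\sum_{j=1}^{N} x_j, \qquad
\sigma_B^2 = \frac{1}{N}\sum_{j=1}^{N} (x_j - \mu_B)^2, \qquad
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}

\frac{\partial L}{\partial x_i}
  = \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial x_i}
  + \frac{\partial L}{\partial \sigma_B^2}\,\frac{\partial \sigma_B^2}{\partial x_i}
  + \frac{\partial L}{\partial \mu_B}\,\frac{\partial \mu_B}{\partial x_i}
```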
After a lengthy derivation we obtain:
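ε is a positive number very close to 0. It appears only inside the factor (σ² + ε)^(-1/2) and can be omitted as long as the batch variance is not itself close to 0.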
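The simplified result is not reproduced in these notes. Written out so that it matches the expression used in the code implementation, and assuming $y_i = \gamma \hat{x}_i + \beta$ denotes the layer output (notation not confirmed by the original), it reads:

```latex
\frac{\partial L}{\partial x_i}
  = \frac{\gamma\,(\sigma_B^2 + \epsilon)^{-1/2}}{N}
    \left(
      N\,\frac{\partial L}{\partial y_i}
      - \hat{x}_i \sum_{j=1}^{N} \frac{\partial L}{\partial y_j}\,\hat{x}_j
      - \sum_{j=1}^{N} \frac{\partial L}{\partial y_j}
    \right)
```

The two sums are exactly the gradients with respect to $\gamma$ and $\beta$, which is why the code can reuse `dgamma` and `dbeta` when computing `dx`.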
Code implementation
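    import numpy as np

    def batchnorm_backward_alt(dout, cache):
        # Assumes the forward pass stored eps in the cache alongside gamma, xhat and the batch variance.
        gamma, xhat, variance, eps = cache
        N, _ = dout.shape
        dbeta = np.sum(dout, axis=0)          # gradient w.r.t. the shift beta
        dgamma = np.sum(xhat * dout, axis=0)  # gradient w.r.t. the scale gamma
        # eps can be dropped here when the variance is not close to 0
        dx = (gamma * (variance + eps) ** (-0.5) / N) * (N * dout - xhat * dgamma - dbeta)
        return dx, dgamma, dbeta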
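One way to gain confidence in the closed-form `dx` is to compare it against a centered finite-difference gradient. The sketch below is self-contained; the helper names `bn_forward`, `bn_backward`, and `numeric_grad`, the cache layout, and the step size `h=1e-5` are assumptions for illustration rather than part of the original post.

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    # Training-mode forward pass; cache layout is an assumption of this sketch.
    mu, var = x.mean(axis=0), x.var(axis=0)
    xhat = (x - mu) / np.sqrt(var + eps)
    return gamma * xhat + beta, (gamma, xhat, var, eps)

def bn_backward(dout, cache):
    # Closed-form backward pass, same expression as in the notes.
    gamma, xhat, var, eps = cache
    N = dout.shape[0]
    dbeta = dout.sum(axis=0)
    dgamma = (xhat * dout).sum(axis=0)
    dx = gamma / (N * np.sqrt(var + eps)) * (N * dout - xhat * dgamma - dbeta)
    return dx, dgamma, dbeta

def numeric_grad(f, x, dout, h=1e-5):
    # Centered finite differences of sum(f(x) * dout) with respect to each entry of x.
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        old = x[idx]
        x[idx] = old + h
        plus = f(x)
        x[idx] = old - h
        minus = f(x)
        x[idx] = old  # restore the original entry
        grad[idx] = np.sum((plus - minus) * dout) / (2 * h)
    return grad

rng = np.random.default_rng(0)
N, D = 6, 3
x = rng.normal(size=(N, D))
gamma, beta = rng.normal(size=D), rng.normal(size=D)
dout = rng.normal(size=(N, D))

out, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(dout, cache)
dx_num = numeric_grad(lambda v: bn_forward(v, gamma, beta)[0], x, dout)
max_err = np.max(np.abs(dx - dx_num))
print(max_err)  # should be tiny if the analytic gradient is correct
```

Because the forward pass here keeps `eps` under the square root, the analytic expression with `eps` is exact for it, so any remaining difference comes from the finite-difference approximation itself.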
references
https://blog.csdn.net/leayc/article/details/77645877
http://costapt.github.io/2016/07/09/batch-norm-alt/