# Batch Normalization Code
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Apply batch normalization to ``X`` and return updated running statistics.

    When autograd is disabled (``torch.is_grad_enabled()`` is False, e.g.
    inside ``torch.no_grad()``), ``X`` is normalized with the supplied
    running statistics, and those statistics are returned unchanged.
    Otherwise ``X`` is normalized with the statistics of the current batch,
    and exponential moving averages of the mean and variance are returned.

    Args:
        X: Input tensor, either 2-D or 4-D. For 2-D input the statistics
            are reduced over dim 0. For 4-D input they are reduced over
            dims 0, 2 and 3, so dim 1 is treated as the channel axis.
        gamma: Scale parameter; must broadcast against ``X``.
        beta: Shift parameter; must broadcast against ``X``.
        moving_mean: Running mean, used for normalization in prediction mode.
        moving_var: Running variance, used for normalization in prediction mode.
        eps: Small constant added to the variance for numerical stability.
        momentum: Weight given to the *old* running statistic in the update
            ``new = momentum * old + (1 - momentum) * batch_stat``.
            This is the opposite convention from ``torch.nn.BatchNorm*``.

    Returns:
        Tuple ``(Y, moving_mean, moving_var)``: the normalized, scaled and
        shifted output, followed by the (possibly updated) running statistics.

    Raises:
        AssertionError: In training mode, if ``X`` is neither 2-D nor 4-D.
            The check is an ``assert``, so it is skipped under ``python -O``.
    """
    if not torch.is_grad_enabled():
        # Prediction mode: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4), "X must be a 2-D or 4-D tensor"
        if len(X.shape) == 2:
            # 2-D input: per-feature statistics over the batch dimension.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 4-D input: per-channel statistics over dims 0, 2 and 3;
            # keepdim=True keeps the result broadcastable against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training mode: normalize with the current batch statistics.
        # The variance is biased (a plain mean of squared deviations).
        X_hat = (X - mean) / torch.sqrt(var + eps)
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    # Scale and shift the *normalized* input with the learnable parameters.
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var
# 欢迎关注我的CSDN博客心系五道口,有问题请私信2395856915@qq.com
# (Follow my CSDN blog "心系五道口"; for questions, send a private message to 2395856915@qq.com.)