# nn.BatchNorm2d的具体实现 (how nn.BatchNorm2d computes its output and running statistics)
# 参考 (reference): https://blog.csdn.net/qq_38253797/article/details/116847588
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def _bn(batch=None, momentum=0.1):
    """Reproduce one training-mode ``nn.BatchNorm2d`` step by hand and compare.

    Computes the running mean / running variance that a freshly constructed
    ``nn.BatchNorm2d(affine=False)`` should hold after a single forward pass,
    prints them, then prints the values the module actually stores and
    whether the module output matches a hand-written normalization.

    Args:
        batch: 4-D input tensor of shape (N, C, H, W). Defaults to
            ``torch.randn(3, 4, 5, 5)``.
        momentum: momentum passed to ``nn.BatchNorm2d``. ``None`` selects
            PyTorch's cumulative moving average, whose update factor is
            ``1 / num_batches_tracked`` (i.e. 1 on the first batch).

    Returns:
        Tuple ``(expected_mean, expected_var, output, module)``.

    Raises:
        ValueError: if ``batch`` is not 4-D.
    """
    if batch is None:
        batch = torch.randn(3, 4, 5, 5)
    if batch.dim() != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got {batch.dim()}-D")

    num_channels = batch.shape[1]
    reduce_dims = (0, 2, 3)  # per-channel statistics, pooled over N, H and W
    # momentum=None -> factor 1/num_batches_tracked, which is 1 on the first batch.
    factor = 1.0 if momentum is None else momentum

    # A fresh module starts from running_mean = 0 and running_var = 1; the
    # running variance is updated with the *unbiased* batch variance.
    init_mean, init_var = 0.0, 1.0
    batch_mean = batch.mean(dim=reduce_dims)
    expected_mean = (1 - factor) * init_mean + factor * batch_mean
    expected_var = (1 - factor) * init_var + factor * batch.var(dim=reduce_dims)
    print(expected_mean)
    print(expected_var)

    m = nn.BatchNorm2d(num_channels, affine=False, momentum=momentum)
    out = m(batch)
    print(out.shape)
    print(m.running_mean)
    print(m.running_var)

    # The output itself is normalized with the *biased* batch variance plus eps.
    shape = (1, num_channels, 1, 1)
    biased_var = batch.var(dim=reduce_dims, unbiased=False)
    manual = (batch - batch_mean.view(shape)) / torch.sqrt(biased_var.view(shape) + m.eps)
    print(torch.allclose(out, manual, atol=1e-5))
    return expected_mean, expected_var, out, m


if __name__ == '__main__':
    _bn()