BN结构分析
-----------------------------------------------2023年12月11日17:04:36---------------------------------------------
测试的时候,是一个一个样本进行测试的,所以没办法求 均值和 方差,所以可以用训练数据的。因为每次做 Mini-Batch 训练时,都会有那个 Mini-Batch 里 m 个训练实例获得的均值和方差,现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后 均值采用训练集所有 batch 均值的期望,方差采用训练集所有 batch 的方差的无偏估计即可得出全局统计量,即:
最后测试阶段,BN的使用公式就是:
----------------------------------------------分割线---------------------------------------------------------------
import torch def print_named_parameters(model): for k, (name, param) in enumerate(model.named_parameters()): print('[{}] {:<25}: {}'.format(k+1, name, param.shape)) def print_named_buffers(model): for k, (name, module) in enumerate(model.named_buffers()): print('[{}] {:<25}: {}'.format(k+1, name, module.shape)) if __name__ == '__main__': num_batches, num_channels, height, width = 32, 16, 7, 7 x = torch.randn(num_batches, num_channels, height, width) batchnorm2d = torch.nn.BatchNorm2d(num_channels) y = batchnorm2d(x) print(y.shape) print_named_parameters(batchnorm2d) print_named_buffers(batchnorm2d)
# 输出
torch.Size([32, 16, 7, 7])
[1] weight : torch.Size([16])
[2] bias : torch.Size([16])
[1] running_mean : torch.Size([16])
[2] running_var : torch.Size([16])
[3] num_batches_tracked : torch.Size([])
本文来自博客园,作者:海_纳百川,转载请注明原文链接:https://www.cnblogs.com/chentiao/p/16712468.html,如有侵权联系删除