BN结构分析

-----------------------------------------------2023年12月11日17:04:36---------------------------------------------

测试的时候,是一个一个样本进行测试的,所以没办法求 均值和 方差,所以可以用训练数据的。因为每次做 Mini-Batch 训练时,都会有那个 Mini-Batch 里 m 个训练实例获得的均值和方差,现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后 均值采用训练集所有 batch 均值的期望,方差采用训练集所有 batch 的方差的无偏估计即可得出全局统计量,即:

最后测试阶段,BN的使用公式就是:

 

----------------------------------------------分割线---------------------------------------------------------------

 

import torch

def print_named_parameters(model):
    for k, (name, param) in enumerate(model.named_parameters()):
        print('[{}] {:<25}: {}'.format(k+1, name, param.shape))
     

def print_named_buffers(model):
    for k, (name, module) in enumerate(model.named_buffers()):
        print('[{}] {:<25}: {}'.format(k+1, name, module.shape))
        

if __name__ == '__main__':
    num_batches, num_channels, height, width = 32, 16, 7, 7
    x = torch.randn(num_batches, num_channels, height, width)
    batchnorm2d = torch.nn.BatchNorm2d(num_channels)
    y = batchnorm2d(x)
    print(y.shape)
    print_named_parameters(batchnorm2d)
    print_named_buffers(batchnorm2d)
# 输出

torch.Size([32, 16, 7, 7])
[1] weight : torch.Size([16])
[2] bias : torch.Size([16])
[1] running_mean : torch.Size([16])
[2] running_var : torch.Size([16])
[3] num_batches_tracked : torch.Size([])

 

 

posted @ 2022-09-20 20:34  海_纳百川  阅读(76)  评论(0编辑  收藏  举报
本站总访问量