Fusing BatchNorm into Conv at Inference Time

I. What the BN layer does

Batch Normalization (BN) is usually placed right after a convolution layer in deep networks. A BN layer brings the following benefits:

  • Reduces the need to hand-tune parameters. In some cases dropout and the L2 regularization term can be dropped, or a smaller L2 penalty can be used;
  • Relaxes the requirements on the learning rate. Training converges quickly whether we start from a large initial learning rate or pick a small one;
  • Makes Local Response Normalization (used in AlexNet) unnecessary, since BN itself normalizes the activations;
  • Perturbs the original data distribution, which alleviates overfitting to some extent (it prevents a particular sample from repeatedly dominating mini-batches; the literature reports roughly a 1% accuracy gain);
  • Mitigates vanishing gradients, speeds up convergence, and improves training accuracy.

II. The BN algorithm

The following is the BN procedure at training time.

Input: the previous layer's outputs $ X = \{x_1, x_2, ..., x_m\} $ and the learnable parameters $ \gamma, \beta $.

Procedure:

  1. Compute the mean of the previous layer's outputs

\[\mu_{\beta} = \frac{1}{m} \sum_{i=1}^m x_i \]

where $ m $ is the size of the current training mini-batch.

  2. Compute the variance of the previous layer's outputs

\[\sigma_{\beta}^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\beta})^2 \]

  3. Normalize, giving

\[\hat x_i = \frac{x_i - \mu_{\beta}}{\sqrt{\sigma_{\beta}^2 + \epsilon}} \]

where $ \epsilon $ is a small value close to 0 added to keep the denominator from being zero.

  4. Scale and shift the normalized values to obtain

\[y_i = \gamma \hat x_i + \beta \]

where $ \gamma, \beta $ are learnable parameters.

Note: the above is the training-time procedure. At test time, often only a single sample is fed in, so there is no batch mean $ \mu_{\beta} $ or variance $ \sigma_{\beta}^2 $ to compute. Instead, the mean $ \mu_{\beta} $ is taken as the average of the $ \mu_{\beta} $ values over all training batches, and the variance $ \sigma_{\beta}^2 $ is taken as the unbiased estimate built from the per-batch $ \sigma_{\beta}^2 $ values.
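
To make the four steps concrete, here is a minimal sketch (the batch size, feature count, and eps below are arbitrary choices for illustration) that reproduces a training-mode nn.BatchNorm1d forward pass by hand. PyTorch normalizes with the biased batch variance during training, while the running variance it stores for inference is updated with the unbiased estimate, consistent with the note above.

import torch
import torch.nn as nn

torch.manual_seed(0)
m, num_features, eps = 8, 4, 1e-5           # made-up sizes for the demo
x = torch.randn(m, num_features)

bn = nn.BatchNorm1d(num_features, eps=eps)
bn.train()
y_ref = bn(x)

mu = x.mean(dim=0)                          # step 1: batch mean
var = x.var(dim=0, unbiased=False)          # step 2: (biased) batch variance
x_hat = (x - mu) / torch.sqrt(var + eps)    # step 3: normalize
y = bn.weight * x_hat + bn.bias             # step 4: scale and shift

print(torch.allclose(y, y_ref, atol=1e-6))  # expect True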

III. The principle of fusing BN and Conv at inference time

If a BN layer comes right after a convolution layer (Conv), the convolution and BN can be merged as follows.

Convolution layer:

\[Z = W X + B \]

BN layer:

\[Y = \frac{Z - \mu_{\beta}}{\sqrt{\sigma_{\beta}^2 + \epsilon}} \gamma + \beta \]

Substituting the first equation into the second gives:

\[Y = \frac{W\gamma}{\sqrt{\sigma_{\beta}^2 + \epsilon}} X + \left(\frac{B - \mu_{\beta}}{\sqrt{\sigma_{\beta}^2 + \epsilon}} \gamma + \beta\right) \]

Defining

\[W^{'} = \frac{W\gamma}{\sqrt{\sigma_{\beta}^2 + \epsilon}} \]

\[B^{'} = \frac{B - \mu_{\beta}}{\sqrt{\sigma_{\beta}^2 + \epsilon}} \gamma + \beta \]

we obtain

\[Y = W^{'} X + B^{'} \]

So merging the convolution and BN layers only requires updating the convolution's weights and bias accordingly.
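
As a quick numerical check of $ W^{'} $ and $ B^{'} $ (before the full Conv2d version in the next section), here is a minimal sketch that fuses an nn.Linear into the nn.BatchNorm1d that follows it, in eval mode. The same per-channel algebra applies to a convolution, where the scaling is applied per output channel. The layer sizes and the fake running statistics below are made up for illustration.

import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(3, 5)
bn = nn.BatchNorm1d(5)
bn.eval()
# pretend some running statistics were accumulated during training
bn.running_mean = torch.randn(5)
bn.running_var = torch.rand(5) + 0.5

x = torch.randn(2, 3)
with torch.no_grad():
    y_ref = bn(fc(x))

    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # gamma / sqrt(var + eps)
    w_fused = fc.weight * scale[:, None]                      # W' = scale * W (per output channel)
    b_fused = (fc.bias - bn.running_mean) * scale + bn.bias   # B' = scale * (B - mu) + beta
    y_fused = x @ w_fused.t() + b_fused

print(torch.allclose(y_ref, y_fused, atol=1e-6))  # expect True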

IV. Code

import torch
import torch.nn as nn
import torchvision as tv


class DummyModule(nn.Module):
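    # Identity placeholder used to replace BatchNorm layers once they have been fused away.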
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        # print("Dummy, Dummy.")
        return x


def fuse(conv, bn):
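    # Fold the BN statistics and affine parameters into the conv's weight and bias,
    # scaling each output channel by gamma / sqrt(running_var + eps).
    # Note: the fused conv below is rebuilt without groups/dilation, so this assumes
    # a plain convolution (groups=1), which holds for resnet18.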
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    gamma = bn.weight
    beta = bn.bias

    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)

    w = w * (gamma / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * gamma + beta
    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def fuse_conv_and_bn(conv, bn):
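    # The same fusion expressed with matrices: left-multiply the flattened conv weight
    # by diag(gamma / sqrt(running_var + eps)) and adjust the bias accordingly.
    # Also assumes a plain (groups=1) convolution.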
    # init
    fused_conv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True
    )
    # # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.running_var + bn.eps)))
    fused_conv.weight = nn.Parameter(torch.mm(w_bn, w_conv).view(fused_conv.weight.size()))
    # # prepare spatial bias
    if conv.bias is not None:
        b_conv = conv.bias
    else:
        b_conv = torch.zeros(conv.weight.size(0))
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused_conv.bias = nn.Parameter(torch.matmul(w_bn, b_conv) + b_bn)

    return fused_conv


def fuse_module(m):
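    # Recursively walk the module tree: fuse each BatchNorm2d with the most recently
    # seen sibling Conv2d, replace that conv in place with the fused conv, swap the BN
    # for a DummyModule (identity), and recurse into other submodules.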
    children = list(m.named_children())
    print("***********")
    print(children)
    print("***********")
    c = None
    cn = None

    for name, child in children:
        if isinstance(child, nn.BatchNorm2d):
            # bc = fuse(c, child)
            bc = fuse_conv_and_bn(c, child)
            m._modules[cn] = bc
            m._modules[name] = DummyModule()
            print("==> name: ", name)
            c = None
        elif isinstance(child, nn.Conv2d):
            c = child
            cn = name
        else:
            fuse_module(child)


def test_net(m):
    p = torch.randn([1, 3, 224, 224])
    import time
    s = time.time()
    o_output = m(p)
    print("Original time: ", time.time() - s)

    fuse_module(m)

    s = time.time()
    f_output = m(p)
    print("Fused time: ", time.time() - s)

    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    assert(o_output.argmax() == f_output.argmax())
    # print(o_output[0][0].item(), f_output[0][0].item())
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())


def test_layer():
    p = torch.randn([1, 3, 112, 112])
    conv1 = m.conv1
    bn1 = m.bn1
    o_output = bn1(conv1(p))
    fusion = fuse(conv1, bn1)
    f_output = fusion(p)
    print(o_output[0][0][0][0].item())
    print(f_output[0][0][0][0].item())
    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())


if __name__ == "__main__":

    m = tv.models.resnet18(True)
    m.eval()
    print("Layer level test: ")
    test_layer()

    print("============================")
    print("Module level test: ")
    test_net(m)
    # print(m)
    # print(list(m.named_modules()))
    # torch.onnx.export(
    #     m,
    #     (inputs,),
    #     "./after.onnx",
    #     verbose=True
    # )
    # print("converted successed!")
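
For comparison, PyTorch also ships its own eager-mode fuser, torch.quantization.fuse_modules (the exact module path varies across versions, e.g. torch.ao.quantization in newer releases), and the FX tutorial linked below builds an automatic Conv+BN fuser. Below is a minimal sketch of the eager API, fusing just the stem conv1/bn1 pair of resnet18 by name:

import torch
import torchvision as tv
from torch.quantization import fuse_modules

m = tv.models.resnet18(True).eval()

# fuse the stem Conv2d + BatchNorm2d pair by name; deeper pairs can be
# listed the same way, e.g. ["layer1.0.conv1", "layer1.0.bn1"]
fused = fuse_modules(m, [["conv1", "bn1"]])

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print((m(x) - fused(x)).abs().max().item())  # expect a tiny value, around 1e-6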

References

https://blog.csdn.net/wfei101/article/details/78635557

https://zhuanlan.zhihu.com/p/49329030

https://pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html?highlight=batchnorm

https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html?highlight=batchnorm
