BatchNorm推理阶段和Conv合并
一、BN层作用
批量归一化(Batch Normalization,BN)在深度学习中常放在卷积层之后,BN层有以下优点:
- 减少了人为选择参数。在某些情况下可以取消 dropout 和 L2 正则项参数,或者采取更小的 L2 正则项约束参数;
- 减少了对学习率的要求。现在我们可以使用初始很大的学习率或者选择了较小的学习率,算法也能够快速训练收敛;
- 可以不再使用局部响应归一化。BN 本身就是归一化网络(局部响应归一化在 AlexNet 网络中存在);
- 破坏原来的数据分布,一定程度上缓解过拟合(防止每批训练中某一个样本经常被挑选到,文献说这个可以提高 1% 的精度);
- 减少梯度消失,加快收敛速度,提高训练精度。
二、 BN层算法流程
下面给出的是 BN 算法在训练时的过程
输入:上一层输出结果 ,学习参数 。
算法流程:
- 计算上一层输出数据的均值
其中, 是此次训练样本 batch 的大小。
- 计算上一层输出数据的标准差
- 归一化处理,得到
其中 是为了避免分母为 0 而加进去的接近于 0 的很小值
- 重构,对经过上面归一化处理得到的数据进行重构,得到
其中, 为可学习参数。
注:上述是 BN 训练时的过程,但是当在测试阶段时,往往只是输入一个样本,没有所谓的均值 和标准差 。此时,均值 是计算所有 batch 值的平均值得到,标准差 采用每个batch 的无偏估计得到。
三、推理阶段合并BN和conv的原理
如果BN层在卷积层Conv之后,那卷积和BN层可以合并成如下式子。
卷积层
BN层
合并上面两个式子可得:
令
可得
因此只需要更新卷积层的权值和偏置就可以达到合并卷积和BN层的效果。
三、code
import torch
import torch.nn as nn
import torchvision as tv
class DummyModule(nn.Module):
def __init__(self):
super(DummyModule, self).__init__()
def forward(self, x):
# print("Dummy, Dummy.")
return x
def fuse(conv, bn):
w = conv.weight
mean = bn.running_mean
var_sqrt = torch.sqrt(bn.running_var + bn.eps)
gamma= bn.weight
beta= bn.bias
if conv.bias is not None:
b = conv.bias
else:
b = mean.new_zeros(mean.shape)
w = w * (gamma/ var_sqrt).reshape([conv.out_channels, 1, 1, 1])
b = (b - mean)/var_sqrt * gamma+ beta
fused_conv = nn.Conv2d(conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
bias=True)
fused_conv.weight = nn.Parameter(w)
fused_conv.bias = nn.Parameter(b)
return fused_conv
def fuse_conv_and_bn(conv, bn):
# init
fused_conv = torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
# # prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused_conv.weight = nn.Parameter(torch.mm(w_bn, w_conv).view(fused_conv.weight.size()))
# # prepare spatial bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros(conv.weight.size(0))
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fused_conv.bias = nn.Parameter(torch.matmul(w_bn, b_conv) + b_bn)
return fused_conv
def fuse_module(m):
children = list(m.named_children())
print("***********")
print(children)
print("***********")
c = None
cn = None
for name, child in children:
if isinstance(child, nn.BatchNorm2d):
# bc = fuse(c, child)
bc = fuse_conv_and_bn(c, child)
m._modules[cn] = bc
m._modules[name] = DummyModule()
print("==> name: ", name)
c = None
elif isinstance(child, nn.Conv2d):
c = child
cn = name
else:
fuse_module(child)
def test_net(m):
p = torch.randn([1, 3, 224, 224])
import time
s = time.time()
o_output = m(p)
print("Original time: ", time.time() - s)
fuse_module(m)
s = time.time()
f_output = m(p)
print("Fused time: ", time.time() - s)
print("Max abs diff: ", (o_output - f_output).abs().max().item())
assert(o_output.argmax() == f_output.argmax())
# print(o_output[0][0].item(), f_output[0][0].item())
print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())
def test_layer():
p = torch.randn([1, 3, 112, 112])
conv1 = m.conv1
bn1 = m.bn1
o_output = bn1(conv1(p))
fusion = fuse(conv1, bn1)
f_output = fusion(p)
print(o_output[0][0][0][0].item())
print(f_output[0][0][0][0].item())
print("Max abs diff: ", (o_output - f_output).abs().max().item())
print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())
if __name__ == "__main__":
m = tv.models.resnet18(True)
m.eval()
print("Layer level test: ")
test_layer()
print("============================")
print("Module level test: ")
test_net(m)
# print(m)
# print(list(m.named_modules()))
# torch.onnx.export(
# m,
# (inputs,),
# "./after.onnx",
# verbose=True
# )
# print("converted successed!")
参考链接
https://blog.csdn.net/wfei101/article/details/78635557
https://zhuanlan.zhihu.com/p/49329030
https://pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html?highlight=batchnorm
https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html?highlight=batchnorm
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 零经验选手,Compose 一天开发一款小游戏!
· 因为Apifox不支持离线,我果断选择了Apipost!
· 通过 API 将Deepseek响应流式内容输出到前端