【Pytorch】ResNet50中替换BN层为IN层实现

替换BN层为IN层

最近在做实验时,考虑将官方torchvision包中的Resnet模型进行一些更改,ResNet类中有个可选参数_norm_layer可以直接传入nn.InstanceNorm2d,默认为nn.BatchNorm,但是这样更改后,在使用官方的预训练权重时,会发生一些报错,BN层里的一些权重会导致报错,因此用另一种方式实现替换BN层的需求的同时,尽可能使用预训练权重

实现

  1. 定义一个函数来替换 BN 层为 IN 层
import torch.nn as nn
def replace_bn_with_in(module):
"""
遍历网络模块,将 BatchNorm 替换为 InstanceNorm
"""
for name, child in module.named_children():
if isinstance(child, nn.BatchNorm2d):
setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
else:
replace_bn_with_in(child)
  1. 加载预训练的 ResNet 模型
import torchvision.models as models
# 加载预训练的 ResNet 模型(这里以 resnet50 为例)
model = models.resnet50(pretrained=True)
  1. 替换BN为IN
# 将模型中的 BatchNorm 层替换为 InstanceNorm 层
replace_bn_with_in(model)

补充

setattr 函数是 Python 的内置函数,用于设置对象的属性。如果属性不存在,它会创建一个新属性。setattr 函数的使用格式如下:

setattr(object, name, value)
  • 参数
    • object:要设置属性的对象。
    • name:属性的名称,一个字符串。
    • value:要设置的属性值
posted @   chendsome  阅读(40)  评论(0编辑  收藏  举报  
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示