Using an FCN to let ResNet accept images of arbitrary size

Some notes taken while reading this site.

By making a few small modifications to the ResNet18 architecture, we can get a feel for the fully convolutional network approach.

Letting the network accept images of arbitrary size

An ordinary convolutional network only accepts images of a fixed size, because of the fully connected layer nn.Linear.

A fully convolutional network (FCN) uses 1×1 convolution kernels and thus sidesteps this limitation of the fully connected layer.
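
A minimal sketch of that difference (the layer sizes below are invented for illustration): a flattened nn.Linear fixes the total number of input features, while a 1×1 convolution only fixes the number of channels.

import torch
import torch.nn as nn

feat_small = torch.randn(1, 512, 7, 7)     # e.g. backbone output for a 224x224 image
feat_large = torch.randn(1, 512, 10, 12)   # backbone output for a larger image

# Fully connected head: the flattened feature length is baked into the layer,
# so any other spatial size would raise a shape-mismatch RuntimeError.
fc = nn.Linear(512 * 7 * 7, 1000)
print(fc(feat_small.flatten(1)).shape)     # torch.Size([1, 1000])
# fc(feat_large.flatten(1))                # shape mismatch: would fail

# 1x1 convolutional head: only the channel count is fixed, the spatial size is free.
conv = nn.Conv2d(512, 1000, kernel_size=1)
print(conv(feat_small).shape)              # torch.Size([1, 1000, 7, 7])
print(conv(feat_large).shape)              # torch.Size([1, 1000, 10, 12])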

A solution that keeps the fully connected layer

In the torchvision implementation of ResNet, there is an nn.AdaptiveAvgPool2d((1, 1)) right before the final fully connected layer:

class ResNet(nn.Module):
    # excerpt from the __init__ of torchvision's ResNet
    # ...
    self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

nn.AdaptiveAvgPool2d((1, 1)) reduces each channel to a single (1, 1) pixel, so the fully connected layer always receives a fixed number of input features. The input image is therefore not restricted to a fixed aspect ratio or size.

For example, 16-channel input tensors with spatial sizes of 8×5 and 2×2 both come out as 16 channels of size 1×1 after this layer.
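
A quick check of this behavior (shapes written in PyTorch's channels-first N×C×H×W layout):

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((1, 1))
a = torch.randn(1, 16, 8, 5)   # 16 channels, 8x5 spatial size
b = torch.randn(1, 16, 2, 2)   # 16 channels, 2x2 spatial size
print(pool(a).shape)           # torch.Size([1, 16, 1, 1])
print(pool(b).shape)           # torch.Size([1, 16, 1, 1])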

Replacing the fully connected layer with an FCN

We now modify the ResNet code from the previous section.

Specifically:

  • Change nn.AdaptiveAvgPool2d((1, 1)) to nn.AvgPool2d((7, 7))
  • Replace the fully connected layer with torch.nn.Conv2d(in_channels=self.fc.in_features, out_channels=num_classes, kernel_size=1), i.e. a 2D convolution with a kernel size of 1 and the same input and output dimensions as the original fc layer

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision import models

class FullyConvolutionalResnet18(models.ResNet):
    def __init__(self, num_classes=1000, pretrained=False, **kwargs):
        # Start with the standard resnet18 defined in torchvision
        super().__init__(block=models.resnet.BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(models.resnet.model_urls["resnet18"], progress=True)
            self.load_state_dict(state_dict)

        # Replace AdaptiveAvgPool2d with a standard AvgPool2d
        self.avgpool = nn.AvgPool2d((7, 7))

        # Convert the original fc layer to a 1x1 convolutional layer
        self.last_conv = torch.nn.Conv2d(in_channels=self.fc.in_features, out_channels=num_classes, kernel_size=1)
        self.last_conv.weight.data.copy_(self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_(self.fc.bias.data)

The last three lines transfer the weights of the original fully connected layer to the 2D convolutional layer.
The fully connected layer's weight has size out_features×in_features; the convolutional layer's weight has size out_features×in_features×1×1. A simple reshape is therefore enough to migrate the weights.
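
As a sanity check, here is a small sketch (using the FullyConvolutionalResnet18 class above; the random tensor stands in for what avgpool produces from a 224×224 input) showing that the converted head reproduces the original fc output:

import torch

net = FullyConvolutionalResnet18(pretrained=False)
feat = torch.randn(2, 512, 1, 1)                     # what avgpool yields for a 224x224 input

out_fc = net.fc(feat.flatten(1))                     # original fully connected head
out_conv = net.last_conv(feat).flatten(1)            # converted 1x1 convolutional head
print(torch.allclose(out_fc, out_conv, atol=1e-5))   # True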

Complete code

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models
from torchvision import transforms
from einops import rearrange

class FNC_Resnet18(models.ResNet):
    def __init__(self):
        # Build the standard resnet18 architecture
        super(FNC_Resnet18, self).__init__(
            block=models.resnet.BasicBlock, layers=[2, 2, 2, 2], num_classes=1000)

        # Download the pretrained weights beforehand: https://download.pytorch.org/models/resnet18-f37072fd.pth
        state_dict = torch.load('resnet18-f37072fd.pth')
        self.load_state_dict(state_dict)

        # Replace the adaptive pooling with a fixed average pool
        # (a 6x6 window here, rather than the 7x7 used in the previous section)
        self.avgpool = nn.AvgPool2d((6, 6))

        self.last_conv = torch.nn.Conv2d(
            in_channels=self.fc.in_features, out_channels=self.fc.out_features, kernel_size=1)
        self.last_conv.weight.data.copy_(
            self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_(self.fc.bias.data)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.last_conv(x)

        return x
model = FNC_Resnet18()

# Download https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt first
with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
input_image = Image.open('R.jpg')  # any image will do
input_tensor = transform(input_image)
input_batch = input_tensor.unsqueeze(0)

model.eval()
with torch.no_grad():
    output = model(input_batch)
    print(f"output.shape: {output.shape}")
    output = rearrange(output, 'b c h w -> b (h w) c')

    summed_output = output.sum(dim=1)

    # Find the values and indices of the top five scores
    for batch in summed_output:
        top5_values, top5_indices = torch.topk(batch, 5)
        # print the corresponding labels
        for value, index in zip(top5_values, top5_indices):
            print(f"{labels[index]}: {value}")

Author: chirp

Source: https://www.cnblogs.com/chirp/p/17930794.html

License: This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) license.
