利用 FCN 使得 ResNet 允许任意大小图片输入
阅读这个网站上的文章时写下的一些备忘。
通过少量修改 ResNet18 网络结构的形式,对全卷积网络方案一窥究竟。
## 允许网络输入任意大小的图像
一般的卷积网络,会因为全连接层 `nn.Linear` 的存在,而仅允许固定大小的图像输入:`nn.Linear` 的输入特征数在构建时就已固定,而不同大小的图像经过卷积并展平后,得到的特征数各不相同。
全卷积网络 FCN 用卷积核为 1×1 的卷积层代替全连接层,从而回避了这一缺陷。
## 不摒弃全连接层的解决方法
ResNet 的 torchvision 实现中,在最后的全连接层之前有一个 `nn.AdaptiveAvgPool2d((1, 1))`。
class ResNet:
    # Excerpt of the torchvision ResNet layer definitions (presumably from
    # ResNet.__init__; the surrounding code is elided). Shown for illustration
    # only -- it is not runnable on its own.
    # ...
    # Stem: 7x7 convolution with stride 2, followed by BN + ReLU.
    self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    # 3x3 max pool with stride 2: a second 2x spatial downsampling.
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    # Four residual stages; layer2-layer4 are built with stride=2. The
    # replace_stride_with_dilation flags presumably trade that stride for
    # dilation when set -- confirm against torchvision's _make_layer.
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate = replace_stride_with_dilation[0])
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate = replace_stride_with_dilation[1])
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate = replace_stride_with_dilation[2])
    # Pools every channel down to a single 1x1 value whatever the input's
    # spatial size, so the vector fed to fc has one value per channel.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    # Classifier: 512 * block.expansion input features -> num_classes scores.
    self.fc = nn.Linear(512 * block.expansion, num_classes)
`nn.AdaptiveAvgPool2d((1, 1))` 会把每个通道都池化成一个 1×1 的像素点,保证全连接层接收的输入特征数量固定(等于通道数)。输入图像因此不局限于固定的长宽比和大小。
例如,大小为 8×5×16 或 2×2×16 的 tensor(H×W×C),经过此层后都会变为 1×1×16。
## 使用 FCN 替换全连接层
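对上一节的 ResNet 代码进行修改。
具体来说:
- 将 `nn.AdaptiveAvgPool2d((1, 1))` 修改为 `nn.AvgPool2d((7, 7))`;
- 将全连接层换为 `torch.nn.Conv2d(in_channels=self.fc.in_features, out_channels=num_classes, kernel_size=1)`,即卷积核为 1×1、输入通道数等于原全连接层输入特征数、输出通道数等于类别数的二维卷积层;
- 重写 `forward`:去掉原 `forward` 中的展平和 `self.fc`,改为在 `avgpool` 之后直接调用 `last_conv`。否则新建的卷积层根本不会被用到,而且当输入变大、池化结果不再是 1×1 时,展平得到的特征数会与全连接层不匹配。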
class FullyConvolutionalResnet18(models.ResNet):
    """ResNet-18 whose classifier is a 1x1 convolution, so it accepts any input size.

    Compared with ``torchvision.models.ResNet``:

    * ``avgpool`` is a fixed-window ``nn.AvgPool2d((7, 7))`` instead of
      ``nn.AdaptiveAvgPool2d((1, 1))``, so the pooled map keeps a spatial
      extent that grows with the input;
    * ``last_conv`` is a 1x1 ``Conv2d`` initialised with ``fc``'s weight and
      bias, i.e. the same linear classifier applied at every spatial position;
    * ``forward`` applies ``last_conv`` to the pooled map instead of flattening
      it into ``fc``.

    The output is therefore a class-score map of shape ``(N, num_classes, H', W')``.

    Args:
        num_classes: Number of output classes (channels of the score map).
        pretrained: If True, load torchvision's ImageNet ResNet-18 checkpoint
            before the classifier is converted. That checkpoint has a
            1000-class ``fc``, so loading fails for any other ``num_classes``.
        **kwargs: Forwarded unchanged to ``torchvision.models.ResNet``.
    """

    # Used when torchvision no longer exposes ``models.resnet.model_urls``
    # (it was removed in newer releases).
    _RESNET18_URL = "https://download.pytorch.org/models/resnet18-f37072fd.pth"

    def __init__(self, num_classes=1000, pretrained=False, **kwargs):
        # Start with the standard resnet18 layout: BasicBlock, [2, 2, 2, 2].
        super().__init__(
            block=models.resnet.BasicBlock,
            layers=[2, 2, 2, 2],
            num_classes=num_classes,
            **kwargs,
        )
        if pretrained:
            from torch.hub import load_state_dict_from_url

            url = getattr(models.resnet, "model_urls", {}).get("resnet18", self._RESNET18_URL)
            state_dict = load_state_dict_from_url(url, progress=True)
            self.load_state_dict(state_dict)

        # Replace AdaptiveAvgPool2d with a fixed-window AvgPool2d so the output's
        # spatial size follows the input instead of always collapsing to 1x1.
        self.avgpool = nn.AvgPool2d((7, 7))

        # Convert the original fc layer to an equivalent 1x1 convolution:
        # fc.weight is (out, in) and the conv weight needs (out, in, 1, 1).
        self.last_conv = torch.nn.Conv2d(in_channels=self.fc.in_features, out_channels=num_classes, kernel_size=1)
        self.last_conv.weight.data.copy_(self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_(self.fc.bias.data)

    def forward(self, x):
        """Return the class-score map for the image batch ``x``.

        ``x`` is presumably an ``(N, 3, H, W)`` float tensor; H and W must be
        large enough that layer4's feature map is at least 7x7, otherwise the
        7x7 average pool has no full window.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # The inherited forward would flatten and call self.fc, so last_conv
        # would never run; apply the 1x1 conv to the pooled map instead.
        return self.last_conv(x)
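`__init__` 末尾创建 `last_conv` 并复制参数的三行,是把原来全连接层的权重和偏置迁移到二维卷积层。
以默认的 `num_classes=1000` 为例,全连接层的 weight 大小为 $(\text{num\_classes}, \text{in\_features}) = (1000, 512)$;二维卷积层的 weight 大小为 $(\text{num\_classes}, \text{in\_features}, 1, 1) = (1000, 512, 1, 1)$;两者的 bias 大小都是 $(1000,)$。可见,只需用 `view` 在末尾补上两个长度为 1 的维度,就能迁移权重。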
## 整体代码
import torch
import torch.nn as nn
from torchvision import models
from PIL import Image
import cv2
import numpy as np
from matplotlib import pyplot as plt
from torchvision import transforms
from einops import rearrange, reduce, repeat
class FNC_Resnet18(models.ResNet):
    """Fully convolutional ResNet-18 built from a local torchvision checkpoint.

    ("FNC" in the class name is presumably a typo for FCN; the name is kept so
    existing callers keep working.)

    The global ``AdaptiveAvgPool2d`` is replaced by a fixed-window
    ``nn.AvgPool2d(pool_size)``, and ``fc`` is mirrored by ``last_conv``, an
    equivalent 1x1 convolution. ``forward`` therefore returns a class-score map
    of shape ``(N, 1000, H', W')`` instead of a ``(N, 1000)`` vector.

    Args:
        weights_path: Path of the ResNet-18 state dict to load. The default
            file must be downloaded beforehand from
            https://download.pytorch.org/models/resnet18-f37072fd.pth
        pool_size: Kernel size of the average pool applied to layer4's output
            (``nn.AvgPool2d`` also uses it as the stride by default). The
            default ``(6, 6)`` keeps this class's original behaviour. Note that
            ``(7, 7)`` reproduces the original classifier exactly for a
            224x224 input, whose layer4 feature map is 7x7.
    """

    def __init__(self, weights_path='resnet18-f37072fd.pth', pool_size=(6, 6)):
        # Build the resnet18 layout: BasicBlock with [2, 2, 2, 2] blocks per stage.
        super().__init__(
            block=models.resnet.BasicBlock, layers=[2, 2, 2, 2], num_classes=1000)
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only machine; the module itself is built on the CPU anyway.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.load_state_dict(state_dict)

        self.avgpool = nn.AvgPool2d(pool_size)

        # 1x1 conv equivalent of fc: fc.weight is (out, in), the conv weight
        # needs (out, in, 1, 1); the bias shapes already match.
        self.last_conv = torch.nn.Conv2d(
            in_channels=self.fc.in_features, out_channels=self.fc.out_features, kernel_size=1)
        with torch.no_grad():
            self.last_conv.weight.copy_(self.fc.weight.view(*self.fc.weight.shape, 1, 1))
            self.last_conv.bias.copy_(self.fc.bias)

    def forward(self, x):
        """Return the class-score map for the image batch ``x``.

        ``x`` is presumably an ``(N, 3, H, W)`` normalized float tensor; layer4's
        feature map must be at least ``pool_size`` in each dimension.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # No flatten + fc: the 1x1 conv scores every pooled position.
        x = self.last_conv(x)
        return x
model = FNC_Resnet18()
# Class names in index order; download from
# https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
with open('imagenet_classes.txt', encoding='utf-8') as f:
    labels = [line.strip() for line in f]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Any image works ('R.jpg' is just an example). convert('RGB') guarantees the
# 3 channels conv1 and the 3-value Normalize expect, even for RGBA/grayscale
# files; the with-block closes the underlying file.
with Image.open('R.jpg') as img:
    input_image = img.convert('RGB')
input_tensor = transform(input_image)
input_batch = input_tensor.unsqueeze(0)  # add the batch dimension
model.eval()
with torch.no_grad():
    output = model(input_batch)
print(f"output.shape: {output.shape}")
# (b, c, h, w) -> (b, h*w, c): one row of class scores per spatial position,
# then add up the scores of all positions for each class.
output = rearrange(output, 'b c h w -> b (h w) c')
summed_output = output.sum(dim=1)
# Top-5 classes by summed score for each image in the batch.
for batch in summed_output:
    top5_values, top5_indices = torch.topk(batch, 5)
    for value, index in zip(top5_values, top5_indices):
        print(f"{labels[int(index)]}: {value.item():.4f}")
作者:chirp
出处:https://www.cnblogs.com/chirp/p/17930794.html
版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 零经验选手,Compose 一天开发一款小游戏!
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!