PyTorch构建超分辨率模型——常用模块

Import required libraries:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder

Define a simple convolutional block (Conv-BatchNorm-ReLU)

class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and a ReLU.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero-padding added on each side of the input.

    The three layers live in a single ``nn.Sequential`` stored as ``self.conv``,
    so parameter names are ``conv.0.*`` (Conv2d) and ``conv.1.*`` (BatchNorm2d).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # in-place: saves one activation buffer
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply Conv2d -> BatchNorm2d -> ReLU to ``x``."""
        out = self.conv(x)
        return out

Define a simple upscaling block using sub-pixel convolution

class UpscaleBlock(nn.Module):
    """Sub-pixel upscaling: 3x3 conv, PixelShuffle, then ReLU.

    Args:
        in_channels: channels of the input feature map; the output has the
            same number of channels.
        scale_factor: integer factor by which height and width are multiplied.

    The convolution expands the channels to ``in_channels * scale_factor**2``
    so that ``nn.PixelShuffle`` can rearrange them into a
    ``scale_factor``-times larger spatial grid.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * scale_factor * scale_factor
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upscale ``x`` spatially by ``scale_factor`` (channel count unchanged)."""
        return self.relu(self.pixel_shuffle(self.conv(x)))

Define a custom super-resolution model that stacks two ConvBlocks, one UpscaleBlock, and a final convolution back to 3 channels

class SuperResolutionModel(nn.Module):
    """Small sub-pixel super-resolution network for 3-channel images.

    Pipeline: ConvBlock(3->64, 9x9) -> ConvBlock(64->32, 1x1)
    -> UpscaleBlock(32, upscale_factor) -> Conv2d(32->3, 9x9).

    Args:
        upscale_factor: integer factor applied to height and width by the
            UpscaleBlock; every other stage preserves spatial size
            (padding matches kernel size).
    """

    def __init__(self, upscale_factor):
        super().__init__()
        # Wide 9x9 receptive field for feature extraction; padding=4 keeps H and W.
        self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
        # 1x1 channel projection 64 -> 32.
        self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
        self.upscale = UpscaleBlock(32, upscale_factor)
        # Reconstruction back to 3 channels; no activation on the output.
        self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Return the upscaled reconstruction of ``x``."""
        for stage in (self.conv1, self.conv2, self.upscale, self.conv3):
            x = stage(x)
        return x

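Create a custom dataset for image super-resolution. Each sample pairs a low-resolution input with its high-resolution target, both derived from the same image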

class SuperResolutionDataset(torch.utils.data.Dataset):
    """(low-resolution input, high-resolution target) pairs built from an image folder.

    Wraps ``ImageFolder`` and discards its class labels. Both transforms are
    applied to the same loaded image.

    Args:
        image_folder: root directory passed to ``ImageFolder``.
        input_transform: callable mapping the loaded image to the model input.
        target_transform: callable mapping the loaded image to the target.
    """

    def __init__(self, image_folder, input_transform, target_transform):
        self.dataset = ImageFolder(image_folder)
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(low_res, target)`` for the image at ``index``."""
        img, _ = self.dataset[index]  # the folder-derived class label is irrelevant here
        target = self.target_transform(img)
        # Feed the loaded image, not the target, to input_transform: the target
        # is already a tensor, and a transform pipeline ending in ToTensor
        # rejects tensor input.
        low_res = self.input_transform(img)
        return low_res, target

    def __len__(self):
        return len(self.dataset)

Instantiate the model, loss function, and optimizer

# `device` is used by every `.to(device)` call in this script, so it must be
# defined before the model is moved; fall back to the CPU when no GPU exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()  # pixel-wise mean squared error between output and target
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Define input and target transformations for data preprocessing

# High-resolution target size, rounded down to a multiple of upscale_factor so
# that lr_size * upscale_factor == hr_size exactly. Without this, e.g. factor 3
# gives an 85x85 input -> 255x255 output vs. a 256x256 target, and MSELoss fails.
hr_size = 256 - 256 % upscale_factor
lr_size = hr_size // upscale_factor

# Low-resolution model input: bicubic downsample to lr_size x lr_size.
input_transform = transforms.Compose([
    transforms.Resize((lr_size, lr_size), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

# High-resolution target: bicubic resize to hr_size x hr_size.
target_transform = transforms.Compose([
    transforms.Resize((hr_size, hr_size), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

Create DataLoaders for the training and validation data

# Training split: shuffled so batches differ between passes over the data.
train_dataset = SuperResolutionDataset(
    "path/to/train_data",
    input_transform=input_transform,
    target_transform=target_transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

# Validation split: fixed order so the reported loss is reproducible.
val_dataset = SuperResolutionDataset(
    "path/to/val_data",
    input_transform=input_transform,
    target_transform=target_transform,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

Training loop

# Training: optimise the model on train_loader for num_epochs epochs and report
# the sample-weighted mean training loss after each epoch.
num_epochs = 10  # placeholder; tune for the dataset

for epoch in range(num_epochs):
    model.train()  # BatchNorm uses batch statistics and updates its running stats
    train_loss = 0.0
    num_train_samples = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch; weight by batch size so a short
        # final batch does not skew the epoch average.
        batch_n = inputs.size(0)
        train_loss += loss.item() * batch_n
        num_train_samples += batch_n

    # Guard against an empty loader instead of dividing by zero.
    train_loss /= max(num_train_samples, 1)
    print(f"Epoch {epoch + 1}/{num_epochs} - Training Loss: {train_loss:.4f}")

Validation loop

# Validation: evaluate on val_loader without tracking gradients and report the
# sample-weighted mean loss.
model.eval()  # BatchNorm uses its running statistics in eval mode
val_loss = 0.0
num_val_samples = 0

with torch.no_grad():
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # criterion averages over the batch; weight by batch size so a short
        # final batch does not skew the overall average.
        batch_n = inputs.size(0)
        val_loss += loss.item() * batch_n
        num_val_samples += batch_n

# Guard against an empty loader instead of dividing by zero.
val_loss /= max(num_val_samples, 1)
print(f"Validation Loss: {val_loss:.4f}")
posted @   马路野狼  阅读(140)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· DeepSeek在M芯片Mac上本地化部署
点击右上角即可分享
微信分享提示
主题色彩