Import required libraries:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19  # not used by the MSE baseline below; available for an optional perceptual loss
from torchvision.datasets import ImageFolder
Define a simple convolutional block (Conv-BatchNorm-ReLU):
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
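As a quick sanity check (a minimal sketch, not part of the original code), a ConvBlock with stride 1 and matching padding changes only the channel count while preserving spatial size:

# Hypothetical shape check: 3-channel 64x64 input -> 64-channel 64x64 output
block = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
x = torch.randn(1, 3, 64, 64)
print(block(x).shape)  # torch.Size([1, 64, 64, 64])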
Define a simple upscaling block using sub-pixel convolution:
class UpscaleBlock(nn.Module):
    def __init__(self, in_channels, scale_factor):
        super(UpscaleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.relu(x)
        return x
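To illustrate the sub-pixel idea (again a small sketch, not from the original code): the 3x3 convolution expands C channels to C * scale_factor**2, and PixelShuffle rearranges those extra channels into a feature map that is scale_factor times larger in each spatial dimension:

# For 32 input channels and scale_factor=2: 32 -> 128 channels, then back to 32 at 2x resolution
up = UpscaleBlock(32, scale_factor=2)
x = torch.randn(1, 32, 64, 64)
print(up(x).shape)  # torch.Size([1, 32, 128, 128])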
Define a custom super-resolution model (e.g., using ConvBlocks and UpscaleBlocks):
class SuperResolutionModel(nn.Module):
    def __init__(self, upscale_factor):
        super(SuperResolutionModel, self).__init__()
        self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
        self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
        self.upscale = UpscaleBlock(32, upscale_factor)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.upscale(x)
        x = self.conv3(x)
        return x
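An end-to-end shape check (hypothetical, using random data) confirms that a 2x model maps a 3x128x128 low-resolution input to a 3x256x256 output:

net = SuperResolutionModel(upscale_factor=2)
lr = torch.randn(1, 3, 128, 128)
print(net(lr).shape)  # torch.Size([1, 3, 256, 256])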
Create a custom dataset for image super-resolution. Both the low-resolution input and the high-resolution target are generated from the same source image with fixed bicubic resizes, so every pair stays aligned:
class SuperResolutionDataset(torch.utils.data.Dataset):
    def __init__(self, image_folder, input_transform, target_transform):
        self.dataset = ImageFolder(image_folder)
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img, _ = self.dataset[index]           # the class label from ImageFolder is ignored
        target = self.target_transform(img)    # high-resolution target tensor
        input_img = self.input_transform(img)  # low-resolution input tensor
        return input_img, target

    def __len__(self):
        return len(self.dataset)
Instantiate the model, loss function, and optimizer, and define the input/target transforms:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# The target is the image resized to 256x256; the input is the same image
# resized to 256 / upscale_factor in each dimension.
input_transform = transforms.Compose([
    transforms.Resize((256 // upscale_factor, 256 // upscale_factor), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor()
])
target_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
    transforms.ToTensor()
])
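A quick illustration (not in the original post) of what the two transforms produce for upscale_factor = 2, using a dummy PIL image:

from PIL import Image
dummy = Image.new("RGB", (500, 375))
print(input_transform(dummy).shape)   # torch.Size([3, 128, 128])
print(target_transform(dummy).shape)  # torch.Size([3, 256, 256])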
Create DataLoaders for the training and validation data:
train_dataset = SuperResolutionDataset("path/to/train_data", input_transform, target_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_dataset = SuperResolutionDataset("path/to/val_data", input_transform, target_transform)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
Training loop:
num_epochs = 100  # adjust to your dataset size and compute budget
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.4f}")
Validation loop:
model.eval()
val_loss = 0.0
with torch.no_grad():
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        val_loss += loss.item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")