U-net基础代码

image

import torch
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.elu = nn.ELU()
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.elu(x)
x = self.bn(x)
return x
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
# Encoder
self.conv1 = ConvBlock(in_channels, 64)
self.conv2 = ConvBlock(64, 128)
self.conv3 = ConvBlock(128, 256)
self.conv4 = ConvBlock(256, 512)
# Decoder
self.upconv5 = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
self.conv6 = ConvBlock(1024, 512)
self.upconv7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.conv8 = ConvBlock(512, 256)
self.upconv9 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.conv10 = ConvBlock(256, 128)
self.upconv11 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.conv12 = ConvBlock(128, 64)
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
# Encoder
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
# Decoder
up5 = self.upconv5(conv4)
concat5 = torch.cat([up5, conv3], dim=1)
conv6 = self.conv6(concat5)
up7 = self.upconv7(conv6)
concat7 = torch.cat([up7, conv2], dim=1)
conv8 = self.conv8(concat7)
up9 = self.upconv9(conv8)
concat9 = torch.cat([up9, conv1], dim=1)
conv10 = self.conv10(concat9)
up11 = self.upconv11(conv10)
concat11 = torch.cat([up11, x], dim=1)
conv12 = self.conv12(concat11)
output = self.final_conv(conv12)
return output
# 使用示例
model = UNet(in_channels=3, out_channels=1)
posted @   辛宣  阅读(54)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
历史上的今天:
2022-01-16 练习1:定义一个类描述数字时钟。
点击右上角即可分享
微信分享提示