关于Unet的一些代码
My_dataset
## 1.0.1 from torch.utils.data import Dataset import torch import os from PIL import Image def read_split(root, mode:str = "train"): if os.path.exists(root) == False: print("--the dataset does not exict.--") exit() Myclass=[cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] # print(Myclass) # exit() #['imagesTr', 'imagesTs', 'labelsTr', 'labelsTs'] Myclass.sort() if mode == "train": train_name = [cla for cla in os.listdir(os.path.join(root, Myclass[0]))] train_img_path = [os.path.join(root, Myclass[0], name) for name in train_name] train_label_path = [os.path.join(root, Myclass[1], name) for name in train_name] return train_img_path, train_label_path elif mode == "test": test_name = [cla for cla in os.listdir(os.path.join(root, Myclass[0]))] test_img_path = [os.path.join(root, Myclass[0], name) for name in test_name] # test_label_path = [os.path.join(root, Myclass[1], name) for name in test_name] return test_img_path #print(train_label_path,train_img_path,test_img_path,test_label_path) # return train_img_path,train_label_path,test_img_path,test_label_path class My_Dataset(Dataset): def __init__(self, img_path: list, label_path: list, transforms= None, dataset_type = "train"): self.img_path = img_path self.transforms = transforms self.dataset_type = dataset_type if dataset_type == "train": self.label_path = label_path def __len__(self): return len(self.img_path) def __getitem__(self, item): if self.dataset_type == "train":#read file from the path img = Image.open(self.label_path[item]).convert("L") label = Image.open(self.label_path[item]).convert("L") if self.transforms is not None:#transforms img = self.transforms(img) label = self.transforms(label) return img, label elif self.dataset_type == "test": img = Image.open(self.img_path[item]).convert("L") # print(type(img)) size = img.size # print(size) # print(type(size)) if self.transforms is not None: img = self.transforms(img) return img, self.img_path[item], size # label = 
self.transforms(label) # print(img.shape,label.shape) # exit() #print(img.shape,label.shape) # img = img / 256 # label = label /256 @staticmethod def collate_fn(batch): tmp = tuple(zip(*batch))#解包 if len(tmp) == 3: images, path, size = tmp images = torch.stack(images, dim=0) return images, path, size elif len(tmp) == 2: images, labels = tmp images = torch.stack(images, dim=0) labels = torch.stack(labels, dim=0) return images, labels # path = "./archive" # read_split(path) # def Save_Image(data, save_path): # array = data.cpu().detach().numpy() # print(array.shape) # img = Image.fromarray(array, mode="L") # img.save(save_path) # print("finish") def unsample_add(size:tuple): # print(size) if size[0][0] >= size[0][1]: if (size[0][0] - size[0][1]) %2 == 0: left_add = right_add = int((size[0][0] - size[0][1]) /2) elif (size[0][0] - size[0][1]) %2 !=0: right_add = int((size[0][0] - size[0][1]) /2) left_add = right_add + 1 return left_add, right_add, True else: if (size[0][1] - size[0][0]) %2 == 0: up_add = down_add = int((size[0][1] - size[0][0]) /2) elif (size[0][1] - size[0][0]) %2 !=0: up_add = int((size[0][1] - size[0][0]) /2) down_add = up_add + 1 return up_add, down_add, False
Network
## 1.0.1
import torch.nn as nn
import torch
from torch.nn import functional as F


class Conv(nn.Module):
    """Double 3x3 conv block: (Conv -> BN -> Dropout -> LeakyReLU) x 2."""

    def __init__(self, in_channel, out_channel):
        super(Conv, self).__init__()
        self.conv_layer = nn.Sequential(
            # first conv
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.Dropout(0.3),  # regularisation against over-fitting
            nn.LeakyReLU(),
            # second conv
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.Dropout(0.4),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.conv_layer(x)


class Down(nn.Module):
    """Downsampling stage: 2x2 max-pool followed by ReLU."""

    def __init__(self, in_channel):
        super(Down, self).__init__()
        self.Down_layer = nn.Sequential(
            nn.MaxPool2d(2),
            nn.ReLU()
        )

    def forward(self, x):
        return self.Down_layer(x)


class Up(nn.Module):
    """Upsampling stage: transposed conv halving channels, then skip concat."""

    def __init__(self, in_channel):
        super(Up, self).__init__()
        # 2x2 transposed conv with stride 2 doubles the spatial size
        # and halves the channel count.
        self.Up_layer = nn.ConvTranspose2d(in_channel, in_channel // 2, 2, 2)

    def forward(self, x, res):
        x = self.Up_layer(x)
        # Concatenate with the encoder feature map along channels.
        return torch.cat((x, res), dim=1)


class Unet(nn.Module):
    """U-Net with four down/up stages and a sigmoid output.

    Args:
        in_channel: number of input image channels.
        out_channel: number of output mask channels.
    """

    def __init__(self, in_channel: int, out_channel: int):
        super(Unet, self).__init__()
        # encoder
        self.Down_Conv1 = Conv(in_channel, 64)
        self.Down1 = Down(64)
        self.Down_Conv2 = Conv(64, 128)
        self.Down2 = Down(128)
        self.Down_Conv3 = Conv(128, 256)
        self.Down3 = Down(256)
        self.Down_Conv4 = Conv(256, 512)
        self.Down4 = Down(512)
        self.Conv = Conv(512, 1024)  # bottleneck
        # decoder
        self.Up1 = Up(1024)
        self.Up_Conv1 = Conv(1024, 512)
        self.Up2 = Up(512)
        self.Up_Conv2 = Conv(512, 256)
        self.Up3 = Up(256)
        self.Up_Conv3 = Conv(256, 128)
        self.Up4 = Up(128)
        self.Up_Conv4 = Conv(128, 64)
        # final 3x3 projection to the requested number of mask channels
        self.pred = nn.Conv2d(64, out_channel, 3, 1, 1)

    def forward(self, x):
        # encoder path, keeping skip features D1..D4
        D1 = self.Down_Conv1(x)
        D2 = self.Down_Conv2(self.Down1(D1))
        D3 = self.Down_Conv3(self.Down2(D2))
        D4 = self.Down_Conv4(self.Down3(D3))
        Y = self.Conv(self.Down4(D4))
        # decoder path with skip connections
        U1 = self.Up_Conv1(self.Up1(Y, D4))
        U2 = self.Up_Conv2(self.Up2(U1, D3))
        U3 = self.Up_Conv3(self.Up3(U2, D2))
        U4 = self.Up_Conv4(self.Up4(U3, D1))
        # FIX: F.sigmoid is the deprecated spelling; torch.sigmoid is
        # the supported equivalent.
        return torch.sigmoid(self.pred(U4))
Train
## 1.0.1
import torch
import os
from torchvision import transforms
import My_dataset
from torch.utils.data import DataLoader
from Network import Unet
import torch.nn as nn

path = "./archive"  # dataset root

# Image/label file lists for training.
train_img_path, train_label_path = My_dataset.read_split(path)

# hyper-parameters
LR = 0.0001
Epoch = 1
Batch_size = 8
Num_worker = min([os.cpu_count(), Batch_size if Batch_size > 1 else 0, 8])
USE_GPU = True
Available = torch.cuda.is_available()

data_transforms = {
    "train": transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(512, antialias=True),
        transforms.CenterCrop((512, 512))
    ])
}

train_set = My_dataset.My_Dataset(img_path=train_img_path,
                                  label_path=train_label_path,
                                  transforms=data_transforms["train"])
train_loader = DataLoader(dataset=train_set,
                          batch_size=Batch_size,
                          shuffle=True,
                          num_workers=Num_worker,
                          collate_fn=train_set.collate_fn)


def main():
    """Train the U-Net on the train split and save the weights to unet.pt."""
    unet = Unet(in_channel=1, out_channel=1)
    # The network already ends with a sigmoid, so plain BCELoss applies.
    loss_function = nn.BCELoss()
    optimizer = torch.optim.RMSprop(unet.parameters(), lr=LR,
                                    weight_decay=1e-8, momentum=0.9)
    if Available and USE_GPU:
        unet = unet.cuda()
        loss_function = loss_function.cuda()
    for epoch in range(Epoch):
        for data in train_loader:
            images, labels = data
            if Available and USE_GPU:
                images = images.cuda()
                labels = labels.cuda()
            optimizer.zero_grad()
            output = unet(images)
            # BUG FIX: the output was previously binarised with
            # torch.where(output > 0.5, 1.0, 0) before the loss, which
            # detaches the computation graph — the subsequent
            # loss.requires_grad_(True) hack made backward() run, but no
            # gradient ever reached the network, so it never trained.
            # BCELoss must be computed on the raw sigmoid probabilities.
            loss = loss_function(output, labels)
            print(loss)
            loss.backward()
            optimizer.step()
    # save parameters
    torch.save(unet.state_dict(), "unet.pt")
    print("finish")


if __name__ == '__main__':
    main()
Test
## 1.0.2
import torch
from Network import Unet
import My_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import os

path = "./archive"
test_img_path = My_dataset.read_split(path, "test")
test_label_path = []  # test split has no labels; placeholder for the Dataset API
Available = torch.cuda.is_available()
USE_GPU = True
mask_path = "./archive/save_mask"

data_transforms = {
    "test": transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(512, antialias=True),
        transforms.CenterCrop((512, 512))
    ])
}

test_set = My_dataset.My_Dataset(img_path=test_img_path,
                                 label_path=test_label_path,
                                 transforms=data_transforms["test"],
                                 dataset_type="test")
test_loader = DataLoader(dataset=test_set,
                         shuffle=True,
                         collate_fn=test_set.collate_fn)


def test():
    """Run inference on the test split and save each predicted mask."""
    unet = Unet(in_channel=1, out_channel=1)
    # BUG FIX: load with map_location so CUDA-saved weights can also be
    # deserialised on a CPU-only machine.
    device = "cuda" if Available and USE_GPU else "cpu"
    unet.load_state_dict(torch.load("unet.pt", map_location=device))
    # BUG FIX: switch to eval mode — otherwise Dropout stays active and
    # BatchNorm uses per-batch statistics during inference.
    unet.eval()
    if Available and USE_GPU:
        unet = unet.cuda()
    for data in test_loader:
        images, img_path, size = data
        print(img_path, size)  # size is the original (width, height)
        name = os.path.split(img_path[0])[-1]  # source file name
        if Available and USE_GPU:
            images = images.cuda()
        with torch.no_grad():  # inference only — no gradients needed
            output = unet(images)
        # Binarise the prediction.
        # NOTE(review): `< 0.5` keeps the sigmoid *background*; confirm
        # the mask polarity matches how the training labels are encoded.
        output = (output < 0.5).float()
        # Scale the square 512x512 prediction down to the original short
        # side, then pad the long side to restore the original size.
        resize = transforms.Resize(min(size[0][0], size[0][1]), antialias=True)
        output = resize(output)
        add_one, add_two, is_landscape = My_dataset.unsample_add(size)
        if is_landscape:  # width >= height: pad left/right
            pad = torch.nn.ConstantPad2d(padding=(add_one, add_two, 0, 0), value=0)
        else:             # height > width: pad top/bottom
            pad = torch.nn.ConstantPad2d(padding=(0, 0, add_one, add_two), value=0)
        output = pad(output)
        batch, channel, h, w = output.shape
        img = output.reshape((h, w))
        save_path = os.path.join(mask_path, name)
        save_image(img, save_path)
    print("finish")


if __name__ == '__main__':
    test()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)