Some code for U-Net

My_dataset

## 1.0.1
from torch.utils.data import Dataset
import torch
import os
from PIL import Image



def read_split(root, mode: str = "train"):
    if not os.path.exists(root):
        print("--the dataset does not exist.--")
        exit()
    Myclass = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # e.g. ['imagesTr', 'imagesTs', 'labelsTr', 'labelsTs']
    Myclass.sort()
    if mode == "train":
        train_name = os.listdir(os.path.join(root, Myclass[0]))  # imagesTr
        train_img_path = [os.path.join(root, Myclass[0], name) for name in train_name]
        # labels live in labelsTr and are assumed to share the image file names
        train_label_path = [os.path.join(root, Myclass[2], name) for name in train_name]
        return train_img_path, train_label_path

    elif mode == "test":
        test_name = os.listdir(os.path.join(root, Myclass[1]))  # imagesTs
        test_img_path = [os.path.join(root, Myclass[1], name) for name in test_name]
        return test_img_path
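# Example usage (a sketch; assumes the ./archive layout listed in the comment above):
#   train_imgs, train_labels = read_split("./archive", mode="train")
#   test_imgs = read_split("./archive", mode="test")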

class My_Dataset(Dataset):
    def __init__(self, img_path: list, label_path: list, transforms=None, dataset_type="train"):
        self.img_path = img_path
        self.transforms = transforms
        self.dataset_type = dataset_type
        if dataset_type == "train":
            self.label_path = label_path
    def __len__(self):
        return len(self.img_path)
    
    def __getitem__(self, item):
        if self.dataset_type == "train":#read file from the path
            img = Image.open(self.label_path[item]).convert("L")
            label = Image.open(self.label_path[item]).convert("L")
            if self.transforms is not None:#transforms
                img = self.transforms(img)
                label = self.transforms(label)
            return img, label
        
        elif self.dataset_type == "test":
            img = Image.open(self.img_path[item]).convert("L")
            size = img.size  # (width, height) of the original image
            if self.transforms is not None:
                img = self.transforms(img)
            return img, self.img_path[item], size
        
    
    @staticmethod
    def collate_fn(batch):
        tmp = tuple(zip(*batch))  # unzip the batch into per-field tuples
        if len(tmp) == 3:  # test batch: (image, path, size)
            images, path, size = tmp
            images = torch.stack(images, dim=0)
            return images, path, size
        elif len(tmp) == 2:  # train batch: (image, label)
            images, labels = tmp
            images = torch.stack(images, dim=0)
            labels = torch.stack(labels, dim=0)
            return images, labels


# path = "./archive"
# read_split(path)
# def Save_Image(data, save_path):
#     array = data.cpu().detach().numpy()
#     print(array.shape)
#     img = Image.fromarray(array, mode="L")
#     img.save(save_path)
#     print("finish")
def unsample_add(size: tuple):
    # Compute the padding needed to restore the original aspect ratio after the
    # square mask has been resized to the shorter side of the original image.
    # Returns (pad_a, pad_b, is_landscape): pad left/right when is_landscape is
    # True, otherwise pad top/bottom.
    width, height = size[0]
    if width >= height:
        diff = width - height
        right_add = diff // 2
        left_add = diff - right_add
        return left_add, right_add, True
    else:
        diff = height - width
        up_add = diff // 2
        down_add = diff - up_add
        return up_add, down_add, False
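A quick sanity check of unsample_add (a minimal sketch; the (600, 400) size below is a made-up example, not taken from the dataset):

# Hypothetical original image of width 600, height 400, wrapped the same way
# collate_fn returns the size field (a tuple of (width, height) tuples).
size = ((600, 400),)
pad_a, pad_b, is_landscape = unsample_add(size)
print(pad_a, pad_b, is_landscape)  # 100 100 True -> pad 100 px on the left and right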

Network

## 1.0.1
import torch.nn as nn
import torch
from torch.nn import functional as F


#conv
class Conv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Conv, self).__init__()
        self.conv_layer = nn.Sequential(
            #one
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            # Dropout to reduce overfitting
            nn.Dropout(0.3),
            nn.LeakyReLU(),

            #two
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),

            nn.Dropout(0.4),
            nn.LeakyReLU()
        )
    def forward(self, x):
        return self.conv_layer(x)
    

#Down  
class Down(nn.Module):
    def __init__(self, in_channel):
        super(Down, self).__init__()
        self.Down_layer = nn.Sequential(
            # nn.Conv2d(in_channel, in_channel, 3, 2, 1),  # kernel_size, stride, padding
            nn.MaxPool2d(2),
            nn.ReLU()
        )
    def forward(self, x):
        return self.Down_layer(x)
#Up
class Up(nn.Module):
    def __init__(self, in_channel):
        super(Up, self).__init__()
        # self.Up_layer = nn.Conv2d(in_channel, in_channel//2, 1,1)
        self.Up_layer = nn.ConvTranspose2d(in_channel, in_channel//2,2,2)

    def forward(self, x, res):
        # up = F.interpolate(x, scale_factor=2, mode="nearest")
        # x = self.Up_layer(up)
        # print(x.shape)
        x = self.Up_layer(x)
        # print(res.shape, x.shape)
        return torch.cat((x, res), dim=1)  # concatenate with the skip connection along the channel dim


class Unet(nn.Module):
    def __init__(self, in_channel : int, out_channel : int):
        super(Unet, self).__init__()
        #Down
        self.Down_Conv1 = Conv(in_channel, 64)
        self.Down1 = Down(64)

        self.Down_Conv2 = Conv(64,128)
        self.Down2 = Down(128)

        self.Down_Conv3 = Conv(128,256)
        self.Down3 = Down(256)

        self.Down_Conv4 = Conv(256,512)
        self.Down4 = Down(512)

        self.Conv = Conv(512,1024)

        #Up
        self.Up1 = Up(1024)
        self.Up_Conv1 = Conv(1024,512)

        self.Up2 = Up(512)
        self.Up_Conv2 = Conv(512,256)
        
        self.Up3 = Up(256)
        self.Up_Conv3 = Conv(256,128)

        self.Up4 = Up(128)
        self.Up_Conv4 = Conv(128,64)

        self.pred = nn.Conv2d(64, out_channel, 3, 1, 1)  # in, out, kernel_size, stride, padding

    def forward(self, x):
        #Down
        D1 = self.Down_Conv1(x)
        D2 = self.Down_Conv2(self.Down1(D1))
        D3 = self.Down_Conv3(self.Down2(D2))
        D4 = self.Down_Conv4(self.Down3(D3))

        Y = self.Conv(self.Down4(D4))
        # print(Y.shape, D4.shape)
        
        U1 = self.Up_Conv1(self.Up1(Y, D4))
        U2 = self.Up_Conv2(self.Up2(U1, D3))
        U3 = self.Up_Conv3(self.Up3(U2, D2))
        U4 = self.Up_Conv4(self.Up4(U3, D1))

        return torch.sigmoid(self.pred(U4))
        # return self.pred(U4)
    

# if __name__ == '__main__':
#     a = torch.randn(2, 1, 256, 256)
#     net = Unet(in_channel=1, out_channel=1)
#     print(net(a).shape)
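A minimal shape check for the network (a sketch, not part of the original files; it assumes the 512×512 grayscale inputs produced by the training transforms). Because the encoder halves the resolution four times and the decoder doubles it four times, the input height and width should be divisible by 16 so the skip connections line up:

import torch
from Network import Unet

x = torch.randn(1, 1, 512, 512)  # batch, channel, height, width
assert x.shape[-1] % 16 == 0 and x.shape[-2] % 16 == 0  # four 2x poolings
net = Unet(in_channel=1, out_channel=1)
print(net(x).shape)  # torch.Size([1, 1, 512, 512])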

Train

## 1.0.1
import torch
import os
from torchvision import transforms
import My_dataset
from torch.utils.data import DataLoader
from Network import Unet
import torch.nn as nn
# from torch.utils.tensorboard import SummaryWriter

path = "./archive"
#path
train_img_path, train_label_path = My_dataset.read_split(path)
#print(train_img_path,train_label_path,test_img_path,test_label_path)
#transforms

#super_para
LR = 0.0001
Epoch = 1
Batch_size = 8
Num_worker = min([os.cpu_count(), Batch_size if Batch_size>1 else 0,8])
USE_GPU = True
Available = torch.cuda.is_available()
# writer = SummaryWriter(log_dir='runs/MNIST_experiment')

# transforms
data_transforms = {
        "train": transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(512,antialias=True),
            transforms.CenterCrop((512,512))
            ])
    }

# dataset and dataloader
train_set = My_dataset.My_Dataset(img_path=train_img_path,
                           label_path=train_label_path,
                           transforms=data_transforms["train"])

train_loader = DataLoader(dataset=train_set,
                              batch_size=Batch_size,
                              shuffle=True,
                              num_workers=Num_worker,
                              collate_fn=train_set.collate_fn)

def main():
    # build the network
    unet = Unet(in_channel=1, out_channel=1)
    # loss function
    loss_function = nn.BCELoss()
    # optimizer
    optimizer = torch.optim.RMSprop(unet.parameters(), lr=LR, weight_decay=1e-8, momentum=0.9)
    #GPU
    if Available and USE_GPU:
        unet = unet.cuda()
        loss_function = loss_function.cuda()


    # e = 1
    for epoch in range(Epoch):
        for data in train_loader:
            images, labels = data
            if Available and USE_GPU:
                images = images.cuda()
                labels = labels.cuda()
            optimizer.zero_grad()
            output = unet(images)
            # Do not threshold the output here: a hard 0/1 mask has no gradient,
            # so BCELoss is computed directly on the sigmoid probabilities.
            loss = loss_function(output, labels)
            print(loss.item())
            loss.backward()
            # writer.add_scalar('training loss', loss, e)
            optimizer.step()
            # e += 1

    # save para
    torch.save(unet.state_dict(), "unet.pt") 
    print("finish")
if __name__ == '__main__':
    main()

# unet = Unet()
# img = torch.randn(2,1,256,256)
# output = unet(img)
# print(output.shape)
# cv2.imwrite()
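The hard thresholding that was applied inside the training loop belongs at evaluation time instead, outside the loss computation. A minimal sketch of a hypothetical helper (the name binarize is not from the original code):

import torch

def binarize(probs: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # Turn sigmoid probabilities into hard 0/1 masks for visualization or metrics.
    # Done under no_grad and outside the loss, so training gradients are unaffected.
    with torch.no_grad():
        return (probs > threshold).float()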

Test

## 1.0.2
import torch
from Network import Unet
import My_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import os

path = "./archive"
test_img_path = My_dataset.read_split(path, "test")
test_label_path = []  # empty list: the test split has no labels
Available = torch.cuda.is_available()
USE_GPU = True
mask_path = "./archive/save_mask"

data_transforms = {
        "test":transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(512,antialias=True),
            transforms.CenterCrop((512,512))            
            ])
    }
# test set
test_set = My_dataset.My_Dataset(img_path=test_img_path,
                          label_path=test_label_path,   # labels are only loaded when dataset_type == "train"
                          transforms=data_transforms["test"],
                          dataset_type="test")
# dataloader
test_loader = DataLoader(dataset=test_set,
                             shuffle=True,
                             collate_fn=test_set.collate_fn)

def test():
    unet = Unet(in_channel=1, out_channel=1)
    unet.load_state_dict(torch.load("unet.pt", map_location="cpu"))
    unet.eval()  # switch Dropout / BatchNorm to inference mode
    loss_function = torch.nn.BCEWithLogitsLoss()  # unused below: the test split has no labels
    if Available and USE_GPU:
        unet = unet.cuda()
        loss_function = loss_function.cuda()
    

    for data in test_loader:
        images, path, size = data
        print(path, size)  # size is (width, height)
        name = os.path.split(path[0])[-1]  # file name of the current image
        if Available and USE_GPU:
            images = images.cuda()
        with torch.no_grad():
            output = unet(images)
        output = output < 0.5    # binarize the predicted mask
        output = output.float()  # bool -> float so it can be resized and saved
        # resize the square mask back to the shorter side of the original image
        resize = transforms.Resize(min(size[0][0], size[0][1]), antialias=True)
        output = resize(output)
        # blnum: True if the long side is horizontal (width >= height)
        add_one, add_two, blnum = My_dataset.unsample_add(size)
        if blnum:
            pad = torch.nn.ConstantPad2d(padding=(add_one, add_two, 0, 0), value=0)  # pad left/right
        else:
            pad = torch.nn.ConstantPad2d(padding=(0, 0, add_one, add_two), value=0)  # pad top/bottom
        output = pad(output)
        batch, channel, h, w = output.shape
        img = output.reshape((h, w))

        save_path = os.path.join(mask_path, name)
        save_image(img, save_path)



if __name__ == '__main__':
    test()
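matplotlib is imported above but never used; a minimal sketch of how a saved mask could be previewed (the file name below is hypothetical):

import matplotlib.pyplot as plt
from PIL import Image

mask = Image.open("./archive/save_mask/example.png")  # hypothetical file name
plt.imshow(mask, cmap="gray")
plt.axis("off")
plt.show()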

 
