Fork me on GitHub

遥感图像多类别语义分割(基于Pytorch-Unet)

遥感图像多类别语义分割(基于Pytorch-Unet)
前言

​ 去年前就对这方面感兴趣了,但是当时只实现了二分类的语义分割,对多类别的语义分割没有研究。这一块,目前还是挺热门的,从FCNUnetdeeplabv3+,模型也是不断更迭。

思路
  1. 首先复现了FCN(VOC2012)的语义分割代码,大概了解了布局。
  2. 然后对二分类的代码进行了修改(基于Pytorch-Unet
核心代码与步骤讲解
  1. dataloader读取

    import torch.utils.data as data
    import PIL.Image as Image
    import os
    import numpy as np
    import torch
    
    def make_dataset(root1, root2):
        '''
        @func: 读取数据,存入列表
        @root1: src路径
        @root2: label路径
        '''
        imgs = []                                    #遍历文件夹,添加图片和标签图片路径到列表
        for i in range(650, 811):
            img = os.path.join(root1, "%s.png" % i)
            mask = os.path.join(root2, "%s.png" % i)
            imgs.append((img, mask))
        return imgs
    
    
    class LiverDataset(data.Dataset):
        '''
        @root1
        @root2
        @transform: 对src做归一化和标准差处理, 数据最后转换成tensor
        @target_transform: 不做处理, label为0/1/2/3(long型)..., 数据最后转换成tensor
        '''
        def __init__(self, root1, root2, transform=None, target_transform=None):
            imgs = make_dataset(root1, root2)             
            self.imgs = imgs
            self.transform = transform
            self.target_transform = target_transform
    
        def __getitem__(self, index):
            x_path, y_path = self.imgs[index]
            img_x = Image.open(x_path)    
            img_y = Image.open(y_path)
            if self.transform is not None:
                img_x = self.transform(img_x)
            if self.target_transform is not None:
                img_y = self.target_transform(img_y)
            else:
                img_y = np.array(img_y) # PIL -> ndarry
                img_y = torch.from_numpy(img_y).long() 
            return img_x, img_y
    
        def __len__(self):
            return len(self.imgs)
    

    这一步里至关重要的就是transform部分。当src是rgb图片,label是0、1、2...单通道灰度图类型(一个值代表一个类别)时。src做归一化和标准差处理,可以提升运算效率和准确性。label则不做处理,转换成long就好。

    1. Unet模型搭建
    import torch.nn as nn
    import torch
    from torch import autograd
    
    class DoubleConv(nn.Module):
        def __init__(self, in_ch, out_ch):
            super(DoubleConv, self).__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
    
        def forward(self, input):
            return self.conv(input)
    
    
    class Unet(nn.Module):
        def __init__(self, in_ch, out_ch):
            super(Unet, self).__init__()
    
            self.conv1 = DoubleConv(in_ch, 64)
            self.pool1 = nn.MaxPool2d(2)
            self.conv2 = DoubleConv(64, 128)
            self.pool2 = nn.MaxPool2d(2)
            self.conv3 = DoubleConv(128, 256)
            self.pool3 = nn.MaxPool2d(2)
            self.conv4 = DoubleConv(256, 512)
            self.pool4 = nn.MaxPool2d(2)
            self.conv5 = DoubleConv(512, 1024)
            self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
            self.conv6 = DoubleConv(1024, 512)
            self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
            self.conv7 = DoubleConv(512, 256)
            self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
            self.conv8 = DoubleConv(256, 128)
            self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
            self.conv9 = DoubleConv(128, 64)
            self.conv10 = nn.Conv2d(64, out_ch, 1)
    
        def forward(self, x):
            c1 = self.conv1(x)
            p1 = self.pool1(c1)
            c2 = self.conv2(p1)
            p2 = self.pool2(c2)
            c3 = self.conv3(p2)
            p3 = self.pool3(c3)
            c4 = self.conv4(p3)
            p4 = self.pool4(c4)
            c5 = self.conv5(p4)
            up_6 = self.up6(c5)
            merge6 = torch.cat([up_6, c4], dim=1)
            c6 = self.conv6(merge6)
            up_7 = self.up7(c6)
            merge7 = torch.cat([up_7, c3], dim=1)
            c7 = self.conv7(merge7)
            up_8 = self.up8(c7)
            merge8 = torch.cat([up_8, c2], dim=1)
            c8 = self.conv8(merge8)
            up_9 = self.up9(c8)
            merge9 = torch.cat([up_9, c1], dim=1)
            c9 = self.conv9(merge9)
            c10 = self.conv10(c9)
            return c10
    
    1. 务必注意,多标签分类输出不做概率化处理(softmax)。原因是后面会用nn.CrossEntropyLoss()计算loss,该函数会自动将net()的输出softmax以及lognllloss()运算。

    2. 然而,当二分类的时候,如果计算损失用的是nn.BCELoss(),由于该函数并未做概率化处理,所以需要单独运算sigmoid,通常会在Unet模型的末尾输出

    3. train & test

    这段比较重要,拆成几段来讲。

    最重要的是nn.CrossEntropyLoss(outputs, label)的输入参数

    outputs: net()输出的结果,在多分类中是没有概率化的值

    label: dataloader读取的标签,此处是单通道灰度数组(0/1/2/3...)。

    这里CrossEntropyLoss函数

    对outputs做softmax + log + nllloss()处理;

    对label做one-hot encoded(转换成多维度的0/1矩阵数组,再参与运算)。

    # 1. train
    def train_model(model, criterion, optimizer, dataload, num_epochs=5):
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs-1))
            print('-' * 10)
            dt_size = len(dataload.dataset)
            epoch_loss = 0
            step = 0
            for x, y in dataload:
                step += 1
                inputs = x.to(device)
                labels = y.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                
                # 可视化输出, 用于debug
                # probs1 = F.softmax(outputs, dim=1)  # 1 7 256 256
                # probs = torch.argmax(probs1, dim=1) # 1 1 256 256 
                # print(0 in probs)
                # print(1 in probs)
                # print(2 in probs)
                # print(3 in probs)
                # print(4 in probs)
                # print(5 in probs)
                # print(probs.max())
                # print(probs.min())
                # print(probs)
    
                # print("\n")
                # print(labels.max())
                # print(labels.min())
    
                # labels 1X256X256
                # outputs 1X7X256X256
                loss = criterion(outputs, labels) 
     # crossentropyloss时outputs会自动softmax,不需要手动计算 / 之前bceloss计算sigmoid是因为bceloss不包含sigmoid函数,需要自行添加
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
            print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
        return model
    
    def train():
        model = Unet(3, 7).to(device)
        batch_size = args.batch_size
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        liver_dataset = LiverDataset("data/多分类/src", "data/多分类/label", transform=x_transforms, target_transform=y_transforms)
        dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
        train_model(model, criterion, optimizer, dataloaders)
    
    # 2. transform使用pytorch内置函数
    x_transforms = transforms.Compose([
        transforms.ToTensor(),  # -> [0,1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 
    ])
    y_transforms = None  # label不做处理
    
    # 3.test,输出结果可视化
    def test():
        model = Unet(3, 7)
        model.load_state_dict(torch.load(args.ckp, map_location='cpu'))
        liver_dataset = LiverDataset("data/多分类/src", "data/多分类/label", transform=x_transforms, target_transform=y_transforms)
        dataloaders = DataLoader(liver_dataset, batch_size=1)
        model.eval()
        import matplotlib.pyplot as plt
        plt.ion()
        k = 0
        with torch.no_grad():
            for x, _ in dataloaders:
                y = model(x)  
                
      # 将网络输出的数值转换成概率化数组,再取较大值对应的Index,最后去除第一维维度
                y = F.softmax(y, dim=1)  # 1 7 256 256
                y = torch.argmax(y, dim=1) # 1 1 256 256
                y = torch.squeeze(y).numpy() # 256 256
                
                plt.imshow(y)
                
                # debug
                print(y.max())
                print(y.min())
                print("\n")
                
                skimage.io.imsave('E:/Tablefile/u_net_liver-master_multipleLabels/savetest/{}.jpg'.format(k), y)
                plt.pause(0.1)
                k = k+1
            plt.show()
    
需要注意的地方
  1. 损失函数的选取。

    二分类用BCELoss;多分类用CrossEntropyLoss。

    BCELoss没有做概率化运算(sigmoid)

    CrossEntropyLoss做了softmax + log + nllloss

  2. transform

    src图片做归一化和均值/标准差处理

    label不做处理(单通道数组,0/1/2/3...数值代表类别)

  3. 预测结果不好有可能是loss计算错误的问题,也可能是数据集标注的不够好

  4. 注意计算loss之前的squeeze()函数,用于去掉冗余的维度,使得数组是loss函数需要的shape。(注:BCELoss与CrossEntropy对label的shape要求不同)

  5. 二分类在预测时,net()输出先做sigmoid()概率化处理,然后大于0.5为1,小于0.5为0

结果展示
image-20210119002138462
后记
  1. 还需多复现几个语义分割模型(deeplabv3+/segnet/fcn.../unet+)

  2. 理解模型架构卷积、池化、正则化的具体含义

  3. 掌握调参的技巧(优化器、学习率等)

  4. 掌握迁移学习的方法,节省运算时长

posted @ 2021-01-19 00:28  Rser_ljw  阅读(8077)  评论(9编辑  收藏  举报