语义分割-地表建筑物识别

语义分割-地表建筑物识别

赛题和数据下载:零基础入门语义分割-地表建筑物识别-天池大赛-阿里云天池 (aliyun.com)

实验记录

1.赛题理解与baseline

  1. backbone代码

main.py

import numpy as np
import pandas as pd
import os
import numba, cv2,time
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
import albumentations as A
import torch
import torch.nn as nn
import torch.utils.data as D
import torchvision
from rle import rle_encode,rle_decode
from Tianchidataset import TianChiDataset
from loss import loss_fn
import argparse
from torchvision import transforms as T

# Compute device used for the model and batches: 'cuda' (the current/default
# CUDA device) when CUDA is available, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_model():
    """Build an FCN-ResNet101 segmentation network with a single-channel output.

    The torchvision model is created with the positional ``True`` argument
    (torchvision's legacy "pretrained" flag). The last layer of the main
    classifier head is then replaced by a 1x1 convolution that yields one
    logit map. That suits binary building/background segmentation trained
    with a sigmoid-based loss.
    """
    net = torchvision.models.segmentation.fcn_resnet101(True)
    # classifier[4] is the final 1x1 conv of the FCN head and receives 512
    # channels; swapping it makes model(x)['out'] a single-channel map.
    single_logit_head = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
    net.classifier[4] = single_logit_head
    return net


@torch.no_grad()
def validation(model, loader, loss_fn, device=None):
    """Return the mean validation loss of ``model`` over ``loader``.

    Each batch's loss is weighted by its number of samples. A plain mean of
    per-batch losses would let a smaller final batch count as much as a full
    one and bias the result.

    Args:
        model: network whose call returns a dict with an ``'out'`` tensor.
        loader: iterable of ``(image, target)`` batches.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        device: where batches are moved; ``None`` uses the module-level ``DEVICE``.

    Returns:
        The sample-weighted mean loss as a float, or ``nan`` when ``loader``
        yields no batches.

    Note:
        Leaves ``model`` in eval mode; callers must switch back to train mode.
    """
    if device is None:
        device = DEVICE
    model.eval()
    total_loss = 0.0
    n_samples = 0
    for image, target in loader:
        image, target = image.to(device), target.float().to(device)
        output = model(image)['out']
        batch_size = image.shape[0]
        total_loss += loss_fn(output, target).item() * batch_size
        n_samples += batch_size

    return total_loss / n_samples if n_samples else float('nan')

def _parse_gpu_ids(text):
    """Turn a comma-separated string such as ``"0,2"`` into ``[0, 2]``."""
    return [int(tok) for tok in text.split(',') if tok.strip()]


def parse_args(argv=None):
    """Parse the command-line options shared by training and inference.

    Args:
        argv: optional list of argument strings; ``None`` parses ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``modelDir``, ``data_path``, ``epoch``,
        ``batch_size``, ``image_size`` and ``gpu_ids`` (a list of ints).
    """
    parser = argparse.ArgumentParser(description='Train semantic segmentation network')
    parser.add_argument('--modelDir',
                        help='saved model path name',
                        default="./checkpoints/model_best.pth",
                        type=str)
    parser.add_argument('--data_path',
                        help='dataset path',
                        default='/home/dzh/Desktop/data/dataset/segmentation/tianchi',
                        type=str)
    parser.add_argument('--epoch',
                        help='total train epoch num',
                        default=30,
                        type=int)
    parser.add_argument('--batch_size',
                        help='batch size for training and validation',
                        default=160,
                        type=int)
    parser.add_argument('--image_size',
                        help='side length images are resized to',
                        default=256,
                        type=int)
    # argparse applies ``type`` only to string values, so the list default is
    # kept as-is while "0,1" from the command line becomes [0, 1]. Handing the
    # raw string to DataParallel(device_ids=...) would iterate its characters.
    parser.add_argument('--gpu_ids',
                        help='comma-separated gpu ids: e.g. 0  0,1,2  0,2',
                        default=[0, 1, 2, 3],
                        type=_parse_gpu_ids)
    return parser.parse_args(argv)


def main():
    """Train the segmentation model, checkpointing on the best validation loss.

    Reads ``train_mask.csv`` under ``--data_path`` and builds an augmented
    dataset. Samples with index % 7 == 0 are used for validation and those
    with index % 7 == 1 for training; the model is trained with AdamW. When
    ``--modelDir`` exists, model/optimizer state, loss history and the best
    validation loss are restored, and training resumes at the epoch after the
    saved one. Finally, a plot of the per-epoch training loss is saved to
    ``./loss.png``.
    """
    args = parse_args()
    # -------------------- data loading & augmentation --------------------
    train_mask = pd.read_csv(os.path.join(args.data_path, 'train_mask.csv'), sep='\t', names=['name', 'mask'])
    train_mask['name'] = train_mask['name'].apply(lambda x: os.path.join(args.data_path, 'train/') + x)
    # Sanity check: decoding then re-encoding the first mask should round-trip.
    mask = rle_decode(train_mask['mask'].iloc[0])
    print(rle_encode(mask) == train_mask['mask'].iloc[0])

    trfm = A.Compose([
        A.Resize(args.image_size, args.image_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(),
    ])
    dataset = TianChiDataset(
        train_mask['name'].values,
        train_mask['mask'].fillna('').values,
        trfm, False
    )
    # NOTE(review): only index % 7 == 1 is used for training (about 1/7 of the
    # data), presumably to keep runs short. Use `i % 7 != 0` to train on every
    # non-validation sample.
    valid_idx = [i for i in range(len(dataset)) if i % 7 == 0]
    train_idx = [i for i in range(len(dataset)) if i % 7 == 1]

    train_ds = D.Subset(dataset, train_idx)
    valid_ds = D.Subset(dataset, valid_idx)
    loader = D.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    vloader = D.DataLoader(valid_ds, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # -------------------- model & optimizer --------------------
    model = get_model()
    model.to(DEVICE)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_ids, output_device=0)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
    train_loss = []
    start_epoch = 0
    best_loss = 10
    if os.path.exists(args.modelDir):
        # map_location lets a checkpoint written on GPU be loaded on a CPU-only host.
        checkpoint = torch.load(args.modelDir, map_location=DEVICE)
        model.load_state_dict(checkpoint['state_dict'])
        if 'epoch' in checkpoint:
            # The stored epoch has already been trained; continue with the next one.
            start_epoch = checkpoint['epoch'] + 1
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if 'train_loss' in checkpoint:
            train_loss = checkpoint['train_loss']
        if 'best_loss' in checkpoint:
            # Without this the first resumed epoch would overwrite a better checkpoint.
            best_loss = checkpoint['best_loss']
        print("load model from {}".format(args.modelDir))
    else:
        print("==> no checkpoint found at '{}'".format(args.modelDir))

    # -------------------- training --------------------
    header = r'''
            Train | Valid
    Epoch |  Loss |  Loss | Time, m
    '''
    #          Epoch         metrics            time
    raw_line = '{:6d}' + '\u2502{:7.3f}' * 2 + '\u2502{:6.2f}'
    print(header)

    for epoch in range(start_epoch, args.epoch):
        losses = []
        start_time = time.time()
        model.train()
        for image, target in tqdm(loader):
            image, target = image.to(DEVICE), target.float().to(DEVICE)
            optimizer.zero_grad()
            output = model(image)['out']
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        vloss = float(validation(model, vloader, loss_fn))
        # Plain Python floats (not numpy scalars) keep the checkpoint simple to unpickle.
        epoch_loss = float(np.mean(losses))
        print(raw_line.format(epoch, epoch_loss, vloss, (time.time() - start_time) / 60))
        train_loss.append(epoch_loss)
        if vloss < best_loss:
            best_loss = vloss
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_loss': train_loss,
                'best_loss': best_loss,
            }
            # torch.save does not create missing parent directories.
            save_dir = os.path.dirname(args.modelDir)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(state, args.modelDir)

    plt.figure(figsize=(10, 5))
    plt.title("Loss During Training")
    plt.plot(train_loss, label="loss")
    # One point per epoch (the mean of that epoch's batch losses).
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    plt.legend()
    # Save before show(): show() may close the figure, leaving savefig a blank canvas.
    plt.savefig('./loss.png')
    plt.show()
    #--------------------------------验证-----------------------------------
def valid():
    """Predict building masks for the test set and write an RLE submission.

    Steps:
      1. Load the checkpoint at ``--modelDir``, warning if it is missing.
      2. Run every image listed in ``test_a_samplesubmit.csv`` through the
         model at ``--image_size``.
      3. Threshold the sigmoid output at 0.5 and resize the binary mask to
         512x512.
      4. Write ``name<TAB>rle`` rows to ``./tmp.csv``.
    """
    args = parse_args()
    # Same normalization constants as TianChiDataset's tensor pipeline.
    trfm = T.Compose([
        T.ToPILImage(),
        T.Resize(args.image_size),
        T.ToTensor(),
        T.Normalize([0.625, 0.448, 0.688],
                    [0.131, 0.177, 0.101]),
    ])
    model = get_model()
    model.to(DEVICE)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_ids, output_device=0)
    if os.path.exists(args.modelDir):
        # map_location lets a GPU-trained checkpoint be used on a CPU-only host.
        checkpoint = torch.load(args.modelDir, map_location=DEVICE)
        model.load_state_dict(checkpoint['state_dict'])
        print("load model from {}".format(args.modelDir))
    else:
        # Without trained weights the single-channel head from get_model() is
        # freshly initialised, so the submission would be meaningless.
        print("==> no checkpoint found at '{}'".format(args.modelDir))
    model.eval()
    test_mask = pd.read_csv(os.path.join(args.data_path, 'test_a_samplesubmit.csv'), sep='\t', names=['name', 'mask'])
    test_mask['name'] = test_mask['name'].apply(lambda x: os.path.join(args.data_path, 'test_a/') + x)

    subm = []
    for name in tqdm(test_mask['name']):
        image = trfm(cv2.imread(name))
        with torch.no_grad():
            score = model(image.to(DEVICE)[None])['out'][0][0]
            binary_mask = (score.sigmoid().cpu().numpy() > 0.5).astype(np.uint8)
            # Submissions are scored at the original 512x512 resolution.
            binary_mask = cv2.resize(binary_mask, (512, 512))
        subm.append([os.path.basename(name), rle_encode(binary_mask)])
    pd.DataFrame(subm).to_csv('./tmp.csv', index=None, header=None, sep='\t')

if __name__ == '__main__':
    # Train first (main saves its best checkpoint to --modelDir), then run
    # test-set inference with valid(), which loads that same path.
    main()
    valid()

Tianchidataset.py

import torch.utils.data as D
import cv2
from torchvision import transforms as T
from rle import rle_decode
# Side length passed to T.Resize in TianChiDataset's tensor pipeline.
IMAGE_SIZE = 256
class TianChiDataset(D.Dataset):
    """Dataset of building-segmentation tiles read from disk.

    Args:
        paths: sequence of image file paths, read with ``cv2.imread``.
        rles: run-length-encoded masks aligned with ``paths``, decoded with
            ``rle_decode`` ('' decodes to an all-zero mask).
        transform: callable taking ``image=`` / ``mask=`` keywords and
            returning a dict with ``'image'`` and ``'mask'`` (e.g. an
            albumentations ``Compose``). Used only when ``test_mode`` is False.
        test_mode: if True, items are ``(image_tensor, '')`` with no mask.
        image_size: side length test-mode images are resized to.
    """

    def __init__(self, paths, rles, transform, test_mode=False, image_size=IMAGE_SIZE):
        self.paths = paths
        self.rles = rles
        self.transform = transform
        self.test_mode = test_mode

        self.len = len(paths)
        normalize = T.Normalize([0.625, 0.448, 0.688],
                                [0.131, 0.177, 0.101])
        # Test-mode pipeline: resize, convert to tensor, normalize.
        self.as_tensor = T.Compose([
            T.ToPILImage(),
            T.Resize(image_size),
            T.ToTensor(),
            normalize,
        ])
        # Training pipeline deliberately skips resizing. ``transform`` has
        # already fixed the geometry of image AND mask. Resizing only the image
        # again would desynchronise them whenever the transform produced a
        # different size (e.g. --image_size 512 vs. the fixed 256).
        self.to_tensor = T.Compose([
            T.ToPILImage(),
            T.ToTensor(),
            normalize,
        ])

    # get data operation
    def __getitem__(self, index):
        """Return ``(image_tensor, mask)``.

        In training mode ``mask`` is the augmented mask with a leading axis
        added (``mask[None]``); in test mode it is ''.
        """
        img = cv2.imread(self.paths[index])
        if self.test_mode:
            return self.as_tensor(img), ''
        mask = rle_decode(self.rles[index])
        augments = self.transform(image=img, mask=mask)
        return self.to_tensor(augments['image']), augments['mask'][None]

    def __len__(self):
        """
        Total number of samples in the dataset
        """
        return self.len

loss.py

import torch.nn as nn
class SoftDiceLoss(nn.Module):
    """Soft Dice loss: ``1 - mean(dice)``, with Dice computed over ``dims``.

    Args:
        smooth: constant added to numerator and denominator so that an empty
            prediction against an empty target does not divide 0 by 0.
        dims: axes summed to obtain soft true-positive / false-positive /
            false-negative totals; the resulting Dice scores are averaged.
    """

    def __init__(self, smooth=1., dims=(-2, -1)):
        super().__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        # x is expected to hold probabilities and y a target of the same
        # shape; neither is checked here.
        inv_x = 1 - x
        inv_y = 1 - y
        true_pos = (x * y).sum(self.dims)
        false_pos = (x * inv_y).sum(self.dims)
        false_neg = (inv_x * y).sum(self.dims)

        numerator = 2 * true_pos + self.smooth
        denominator = 2 * true_pos + false_pos + false_neg + self.smooth
        return 1 - (numerator / denominator).mean()
def loss_fn(y_pred, y_true):
    """Return ``0.8 * BCE-with-logits + 0.2 * soft Dice``.

    ``y_pred`` holds raw logits; the Dice term is computed on their sigmoid.
    """
    bce_term = nn.BCEWithLogitsLoss()(y_pred, y_true)
    dice_term = SoftDiceLoss()(y_pred.sigmoid(), y_true)
    return 0.8 * bce_term + 0.2 * dice_term

rle.py

import numpy as np
def rle_encode(im):
    '''
    Run-length encode a mask.

    im: numpy array, 1 - mask, 0 - background
    Returns a space-separated string of 1-based "start length" pairs, with
    pixels scanned in column-major (Fortran) order; '' for an all-zero mask.
    '''
    padded = np.concatenate([[0], im.flatten(order='F'), [0]])
    # 1-based positions (in the unpadded mask) where the value changes.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Convert every run end into a length: end - start.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def rle_decode(mask_rle, shape=(512, 512)):
    '''
    Decode a run-length string into a mask.

    mask_rle: space-separated 1-based "start length" pairs, column-major order
    shape: (height, width) of the array to return
    Returns a uint8 numpy array, 1 - mask, 0 - background
    '''
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, starts + lengths):
        flat[begin:stop] = 1
    return flat.reshape(shape, order='F')

2)提交测试结果

  1. 数据增强方法

# Fragment (not standalone code): extra albumentations transforms added to the
# training augmentation pipeline for this experiment.
# NOTE(review): A.ShiftScaleRotate() appears twice, so it is applied twice.
# NOTE(review): A.Cutout is deprecated in recent albumentations releases in
# favour of A.CoarseDropout; confirm against the installed version.
A.Rotate(),
A.ShiftScaleRotate(),
A.Cutout(),
# A.RandomScale(),
A.ShiftScaleRotate(),

本次改进在原程序的基础上增加了以上数据增强方法,并改用 deeplabv3_resnet101 进行训练。由于服务器资源被占用,只训练了这一个版本,达到的效果如下:

posted @ 2021-02-20 15:27  sariel_sakura  阅读(1139)  评论(0)  编辑  收藏  举报