从语义分割模型得到彩色分割图的过程

将训练好的语义分割模型保存下来，重新加载之后，

通过下面这一行代码得到模型的输出：

output = self.model(image)
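上面得到的 output 通常还不是标签，而是模型对每个像素、每个类别的预测得分（logits），形状一般为 (N, C, H, W)。需要先在类别维度上取 argmax 得到每个像素的类别标签，例如 `pred = torch.argmax(output, dim=1).cpu().numpy()`；再把每一张 (H, W) 的标签图交给下面的解码函数 decode_segmap，就能得到彩色图像。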
 1 def decode_segmap(label_mask, dataset, plot=False):
 2     """Decode segmentation class labels into a color image
 3     解码标签,得到彩色的图像
 4     Args:
 5         label_mask (np.ndarray): an (M,N) array of integer values denoting
 6           the class label at each spatial location.
 7         plot (bool, optional): whether to show the resulting color image
 8           in a figure.
 9     Returns:
10         (np.ndarray, optional): the resulting decoded color image.
11     """
12     if dataset == 'pascal' or dataset == 'coco':
13         n_classes = 21
14         label_colours = get_pascal_labels()
15     elif dataset == 'cityscapes':
16         n_classes = 19
17         label_colours = get_cityscapes_labels()
18     else:
19         raise NotImplementedError
20 
21     r = label_mask.copy()
22     g = label_mask.copy()
23     b = label_mask.copy()
24     for ll in range(0, n_classes):
25         r[label_mask == ll] = label_colours[ll, 0]
26         g[label_mask == ll] = label_colours[ll, 1]
27         b[label_mask == ll] = label_colours[ll, 2]
28     rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
29     rgb[:, :, 0] = r / 255.0
30     rgb[:, :, 1] = g / 255.0
31     rgb[:, :, 2] = b / 255.0
32     if plot:
33         plt.imshow(rgb)
34         plt.show()
35     else:
36         return rgb

下面的主程序从 VOC 训练集中读取样本，把真实标签（ground truth）解码成彩色分割图，并与原图上下排列显示：

 

 1 if __name__ == '__main__':
 2     from dataloaders.utils import decode_segmap
 3     from torch.utils.data import DataLoader
 4     import matplotlib.pyplot as plt
 5     import argparse
 6 
 7     parser = argparse.ArgumentParser()
 8     args = parser.parse_args()
 9     args.base_size = 256
10     args.crop_size = 256
11 
12     voc_train = VOCSegmentation(args, split='train')
13 
14     dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)
15 
16 
17 
18     for ii, sample in enumerate(dataloader):
19         for jj in range(sample["image"].size()[0]):
20             img = sample['image'].numpy()
21             gt = sample['label'].numpy()
22             tmp = np.array(gt[jj]).astype(np.uint8)
23             segmap = decode_segmap(tmp, dataset='pascal')
24             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
25             img_tmp *= (0.229, 0.224, 0.225)
26             img_tmp += (0.485, 0.456, 0.406)
27             img_tmp *= 255.0
28             img_tmp = img_tmp.astype(np.uint8)
29             plt.figure()
30             plt.title('display')
31             plt.subplot(211)
32             plt.imshow(img_tmp)
33             plt.subplot(212)
34             plt.imshow(segmap)
35 
36         if ii == 1:
37             break
38 
39     plt.show(block=True)

 

posted @ 2019-04-15 15:58  you-wh  阅读(2510)  评论(0)  编辑  收藏  举报
Fork me on GitHub