PyTorch: Fine-tuning Mask R-CNN

Following the official PyTorch Mask R-CNN fine-tuning tutorial, and training on a custom multi-class dataset.

1. The official PyTorch Mask R-CNN fine-tuning tutorial does two-class (background/pedestrian) instance segmentation on a pedestrian dataset

For details, see https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

For the official demo, annotating with labelme is straightforward because there are only two classes; with multiple classes, producing the mask dataset is more demanding. This post describes how to build your own multi-class dataset and fine-tune on it.

2. Dataset preparation

2.1 Prepare your own dataset

My dataset is used to segment two classes: scratch and stain.

Put the training images in train, the labelme annotation results in json, the masks converted from the json files (plus a few auxiliary files) in labelme_json, and the mask images collected from labelme_json in cv2_mask. A possible layout is sketched below.
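The folder layout I use looks roughly like this (folder names are mine; adjust to your own paths):

imagesM/
├── train/          # training images
├── json/           # labelme annotation .json files
├── labelme_json/   # one output folder per json: img1_json/, img2_json/, ...
└── cv2_mask/       # mask .png files collected from labelme_json/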

2.2 Installing labelme and annotating

I installed version 3.16.2, via pip install labelme==3.16.2.

Then launch it by running labelme in the environment where it was installed.

When annotating with labelme, instances of the same class must still get distinct labels: two dogs, for example, should be labeled dog1 and dog2, because Mask R-CNN needs a separate mask for each instance.

If both were labeled dog, the two dogs would share the same label pixel value and be treated as a single instance, yielding only one bounding box.

So annotate as follows: within one image the labels must all be distinct, but the same labels may be reused across images. For example:

img1: scratch1, scratch2, stain1

img2: scratch1, scratch2, scratch3, scratch4

img3: scratch1, scratch2, stain1, stain2

img4: stain1, stain2

Once annotation is done, labelme creates the corresponding .json file for each image (these go into the json folder).

2.3 Converting the JSON files to masks

Conversion is run as labelme_json_to_dataset <json-file-path>; I use a modified version of the conversion script (below) that batch-processes a whole folder and writes fixed pixel values per label.

In the script below, the values of the NAME_LABEL_MAP dictionary are the pixel values each label is given in the generated mask. At training time we can decode them however we need: here pixel values 1 and 2 become class 1 (scratch), 3 and 4 become class 2 (stain), and 0 stays background. (If an image has more instances per class, e.g. scratch3 and scratch4 in the example above, extend the map accordingly.)
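As a minimal sketch of that decoding (a toy mask; the values follow NAME_LABEL_MAP below):

import numpy as np

lbl = np.array([[0, 1, 2],
                [3, 4, 0]])       # toy mask: four instances plus background
cls = np.zeros_like(lbl)
cls[(lbl == 1) | (lbl == 2)] = 1  # scratch instances -> class 1
cls[(lbl == 3) | (lbl == 4)] = 2  # stain instances   -> class 2
print(cls)                        # [[0 1 1]
                                  #  [2 2 0]]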

import argparse
import copy
import json
import os
import os.path as osp
import warnings

import numpy as np
import PIL.Image
import yaml

from labelme import utils

# Pixel value written into the mask for each annotated label.
# The keys must exactly match the label names used in labelme.
NAME_LABEL_MAP = {
    '_background_': 0,
    "scratch1": 1,
    "scratch2": 2,
    "stain1": 3,
    "stain2": 4,
}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('json_file')  # folder containing the labelme .json files
    parser.add_argument('-o', '--out', default=None)
    args = parser.parse_args()

    json_file = args.json_file

    file_list = os.listdir(json_file)
    for i in range(0, len(file_list)):
        path = os.path.join(json_file, file_list[i])
        filename = file_list[i][:-5]  # strip the ".json" extension
        if os.path.isfile(path):
            with open(path) as f:
                data = json.load(f)
            img = utils.image.img_b64_to_arr(data['imageData'])
            lbl, lbl_names = utils.shape.labelme_shapes_to_label(img.shape, data['shapes'])

            # remap the per-file label values to the fixed values in NAME_LABEL_MAP
            lbl_tmp = copy.copy(lbl)
            for key_name in lbl_names:
                old_lbl_val = lbl_names[key_name]
                new_lbl_val = NAME_LABEL_MAP[key_name]
                lbl_tmp[lbl == old_lbl_val] = new_lbl_val
            lbl_names_tmp = {}
            for key_name in lbl_names:
                lbl_names_tmp[key_name] = NAME_LABEL_MAP[key_name]

            # assign the new labels to lbl and the lbl_names dict
            lbl = np.array(lbl_tmp, dtype=np.int8)
            lbl_names = lbl_names_tmp

            # one output folder per json file (e.g. img1_json), created in the CWD
            out_dir = osp.basename(file_list[i]).replace('.', '_')
            out_dir = osp.join(osp.dirname(file_list[i]), out_dir)
            if not osp.exists(out_dir):
                os.mkdir(out_dir)

            PIL.Image.fromarray(img).save(osp.join(out_dir, '{}.png'.format(filename)))
            PIL.Image.fromarray(lbl.astype(np.uint8)).save(osp.join(out_dir, '{}_gt.png'.format(filename)))

            with open(osp.join(out_dir, 'label_names.txt'), 'w') as f:
                for lbl_name in lbl_names:
                    f.write(lbl_name + '\n')

            warnings.warn('info.yaml is being replaced by label_names.txt')
            info = dict(label_names=lbl_names)
            with open(osp.join(out_dir, 'info.yaml'), 'w') as f:
                yaml.safe_dump(info, f, default_flow_style=False)

            print('Saved to: %s' % out_dir)


if __name__ == '__main__':
    main()
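Assuming the modified script above replaces labelme's json_to_dataset, or is saved standalone (say as json_to_dataset.py, the name is mine), batch conversion of the json folder is then just: python json_to_dataset.py ./json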

After conversion, each json file has its own output folder (img1_json, img2_json, ...).

2.4 Collect the mask .png from each output folder into the cv2_mask folder

import os

path = '~imagesM\\labelme_json'   # placeholder path: the labelme_json folder
files = os.listdir(path)
for file in files:
    jpath = os.listdir(os.path.join(path, file))
    new = file[:-5]               # strip the trailing "_json" from the folder name
    newnames = os.path.join('~imagesM\\cv2_mask', new)
    # pick the mask written by the conversion script above ({name}_gt.png);
    # indexing jpath directly is fragile because os.listdir order is not guaranteed
    mask = [j for j in jpath if j.endswith('_gt.png')][0]
    filename = os.path.join(path, file, mask)
    print(filename)
    print(newnames)
    os.rename(filename, newnames + '.png')
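A quick sanity check afterwards (a sketch, using the same placeholder paths and the folder names from section 2.1) confirms that every training image now has a mask:

import os

imgs = {os.path.splitext(f)[0] for f in os.listdir('~imagesM\\train')}
masks = {os.path.splitext(f)[0] for f in os.listdir('~imagesM\\cv2_mask')}
print('images without a mask:', sorted(imgs - masks))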

3. Code changes and debugging

3.1 First get the official fine-tuning example running end to end; then adjust the data-loading and training parts as follows.

3.1.1 Modifying class PennFudanDataset

import os
import torch
import numpy as np
import torch.utils.data
from PIL import Image


class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance with 0 being background
        mask = Image.open(mask_path)

        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # The official tutorial has a single foreground class:
        #   labels = torch.ones((num_objs,), dtype=torch.int64)
        # Here pixel values 1,2 are scratch instances and 3,4 are stain instances,
        # so map 1,2 -> class 1 and 3,4 -> class 2.
        labels = torch.as_tensor(obj_ids, dtype=torch.int64)
        for i in range(len(obj_ids)):
            if (obj_ids[i] == 1) | (obj_ids[i] == 2):
                labels[i] = 1
            else:
                labels[i] = 2

        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)
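A quick smoke test of the modified dataset (a sketch; './Images' must contain the PNGImages and PedMasks subfolders the class expects):

dataset = PennFudanDataset('./Images')
img, target = dataset[0]
print(target['labels'])        # e.g. tensor([1, 1, 2]) for two scratches and one stain
print(target['boxes'].shape)   # torch.Size([num_instances, 4])
print(target['masks'].shape)   # torch.Size([num_instances, H, W])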

3.1.2 Modify train.py: set num_classes to 3.

num_classes is 3 (background + scratch + stain) so that the fully connected heads of both the Faster R-CNN box branch and the Mask R-CNN mask branch get replaced: the pretrained maskrcnn_resnet50_fpn was trained on COCO, whose heads predict 91 classes. The replacement itself is done by get_instance_segmentation_model, shown in the predict script in 3.2.1; a minimal training setup is sketched below.
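A minimal training-loop sketch in the style of the official tutorial (train_one_epoch, utils.collate_fn, and get_transform come from the tutorial and the torchvision references/detection helpers it has you copy locally; the epoch count and paths are my assumptions):

import torch
from engine import train_one_epoch   # torchvision references/detection helper
import utils                         # ditto; provides collate_fn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_classes = 3  # background + scratch + stain
dataset = PennFudanDataset('./Images', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)

model = get_instance_segmentation_model(num_classes)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_epochs = 10
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()
    # save in the {"model": state_dict} format that the predict script below loads
    torch.save({"model": model.state_dict()},
               "./save_weights/resNetFpn-model-{}.pth".format(epoch))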

 

3.2 After training, visualize the results.

3.2.1 The predict code:

from MyDataset import PennFudanDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from train_utils.draw_box_utils import draw_box
import torchvision
import torch
import os
import json
import matplotlib.pyplot as plt
from torchvision import transforms
import numpy as np
import cv2
import random


def get_instance_segmentation_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256

    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model


def random_colour_masks(image):
    """
    random_colour_masks
      parameters:
        - image - a predicted binary mask
      method:
        - the mask of each predicted object is given a random colour for visualization
    """
    colours = [[0, 255, 0], [0, 0, 255], [255, 0, 0], [0, 255, 255], [255, 255, 0], [255, 0, 255], [80, 70, 180],
               [250, 80, 190], [245, 145, 50], [70, 150, 250], [50, 190, 190]]
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    r[image == 1], g[image == 1], b[image == 1] = colours[random.randrange(0, 10)]
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask


# kept for reference; the names actually drawn come from classes.json below
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'scratch', 'stain'
]


def main():
    # get device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = get_instance_segmentation_model(num_classes=3)

    # load trained weights
    train_weights = "./save_weights/resNetFpn-model-9.pth"
    assert os.path.exists(train_weights), "{} file does not exist.".format(train_weights)
    model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
    model.to(device)

    # read the class dict mapping class names to label ids
    label_json_path = 'classes.json'
    assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
    with open(label_json_path, 'r') as json_file:
        class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}

    test_data = PennFudanDataset("./Images")

    # put the model in evaluation mode
    model.eval()

    for raw_img, _ in test_data:
        # from pil image to tensor, do not normalize image
        data_transform = transforms.Compose([transforms.ToTensor()])
        img = data_transform(raw_img)

        with torch.no_grad():
            # the model takes a list of images; [0] picks this image's prediction
            prediction = model([img.to(device)])[0]

            boxes = prediction["boxes"].to("cpu").numpy()
            labels = prediction["labels"].to("cpu").numpy()
            scores = prediction["scores"].to("cpu").numpy()

            # scores are sorted in descending order, so keep the leading ones above the threshold
            threshold = 0.3
            num_keep = int((scores > threshold).sum())
            # squeeze only the channel dim so a single mask keeps its leading dim
            masks = (prediction['masks'] > 0.5).squeeze(1).cpu().numpy()
            masks = masks[:num_keep]

            draw_box(raw_img,
                     boxes,
                     labels,
                     scores,
                     category_index,
                     thresh=threshold,
                     line_thickness=2)

            # raw_img now carries the drawn boxes; overlay the coloured masks on top
            cv_img = np.array(raw_img)
            for i in range(len(masks)):
                rgb_mask = random_colour_masks(masks[i])
                cv_img = cv2.addWeighted(cv_img, 1, rgb_mask, 0.5, 0)

            plt.figure(figsize=(20, 30))
            plt.imshow(cv_img)
            plt.xticks([])
            plt.yticks([])
            plt.show()

            # to save the predicted result instead:
            # raw_img.save("test_result.jpg")


if __name__ == '__main__':
    main()
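The script reads class ids from classes.json; for this two-class setup its content would presumably be (an assumption, consistent with the label ids used above):

{"scratch": 1, "stain": 2}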

draw_box_utils.py

import collections
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
import numpy as np

STANDARD_COLORS = [
    'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
    'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
    'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
    'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
    'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
    'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
    'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
    'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
    'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
    'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
    'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
    'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
    'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
    'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
    'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
    'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
    'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
    'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
    'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
    'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
    'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
    'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
    'WhiteSmoke', 'Yellow', 'YellowGreen'
]


def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
    for i in range(boxes.shape[0]):
        if scores[i] > thresh:
            box = tuple(boxes[i].tolist())  # numpy -> list -> tuple
            if classes[i] in category_index.keys():
                class_name = category_index[classes[i]]
            else:
                class_name = 'N/A'
            display_str = str(class_name)
            display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
            box_to_display_str_map[box].append(display_str)
            box_to_color_map[box] = STANDARD_COLORS[
                classes[i] % len(STANDARD_COLORS)]
        else:
            break  # the scores are already sorted descending; once one fails, the rest will too


def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
    try:
        font = ImageFont.truetype('arial.ttf', 10)
    except IOError:
        font = ImageFont.load_default()

    # If the total height of the display strings added to the top of the bounding
    # box exceeds the top of the image, stack the strings below the bounding box
    # instead of above.
    display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
    # Each display_str has a top and bottom margin of 0.05x.
    total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)

    if top > total_display_str_height:
        text_bottom = top
    else:
        text_bottom = bottom + total_display_str_height
    # Reverse list and print from bottom to top.
    for display_str in box_to_display_str_map[box][::-1]:
        text_width, text_height = font.getsize(display_str)
        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                        (left + text_width, text_bottom)], fill=color)
        draw.text((left + margin, text_bottom - text_height - margin),
                  display_str,
                  fill='black',
                  font=font)
        text_bottom -= text_height - 2 * margin


def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8):
    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)

    filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)

    # Draw all boxes onto image.
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    for box, color in box_to_color_map.items():
        xmin, ymin, xmax, ymax = box
        (left, right, top, bottom) = (xmin * 1, xmax * 1,
                                      ymin * 1, ymax * 1)
        draw.line([(left, top), (left, bottom), (right, bottom),
                   (right, top), (left, top)], width=line_thickness, fill=color)
        draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
3.2.2 Results

The predicted boxes, class labels, scores, and colored instance masks are overlaid on each test image. (Result screenshots omitted.)
