PyTorch: Finetuning Mask R-CNN
Finetune Mask R-CNN following the official PyTorch tutorial, then train it on your own multi-class dataset.
1. The official PyTorch Mask R-CNN finetune does two-class (background/pedestrian) instance segmentation on a pedestrian dataset
The tutorial is at https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
For the official demo, labeling data with labelme is simple because there are only 2 classes; with multiple classes, building the mask dataset takes considerably more care. This post describes how to build your own multi-class dataset and finetune on it.
2. Dataset preparation
2.1 First, prepare your own dataset.
My dataset is shown below; it is used to segment two classes, Scratch and Stain.
Put the training images in train, the labelme annotation results in json, the mask images converted from the json files (plus some auxiliary files) in labelme_json, and the mask files collected out of labelme_json in cv2_mask.
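For reference, the resulting layout looks roughly like this (folder names as described above):

train/          # training images
json/           # labelme .json annotation files
labelme_json/   # per-image folders produced by the json-to-mask conversion
cv2_mask/       # one mask .png per training image, collected from labelme_json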
2.2 Installing labelme and labeling
I installed version 3.16.2, via pip install labelme==3.16.2.
Then, in the environment where labelme was installed, run labelme to open it.
When labeling with labelme, instances of the same class must still be given distinct labels: two dogs, for example, should be labeled dog1 and dog2, because Mask R-CNN creates one mask per instance.
If both were labeled dog, the two dogs would share the same label pixel value and be treated as a single instance, with only one bounding box between them.
So label as follows: within a single image every label must be unique, but labels may repeat across images (a small sketch of why follows the list):
img1: scratch1, scratch2, stain1
img2: scratch1, scratch2, scratch3, scratch4
img3: scratch1, scratch2, stain1, stain2
img4: stain1, stain2
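To see why the labels must be distinct, here is a minimal numpy sketch (the toy 4x4 mask is invented for illustration) of how the dataset code in section 3 splits a color-encoded mask into per-instance binary masks:

import numpy as np

# toy mask: 0 = background, 1 and 2 = two separately labeled instances
mask = np.array([[0, 1, 1, 0],
                 [0, 1, 1, 0],
                 [0, 0, 2, 2],
                 [0, 0, 2, 2]])

obj_ids = np.unique(mask)[1:]           # drop the background id -> [1, 2]
masks = mask == obj_ids[:, None, None]  # one binary mask per instance
print(masks.shape)                      # (2, 4, 4): two instances, two boxes
# had both regions been labeled identically, np.unique would return a single
# id and the two objects would collapse into one instance with one box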
Once labeling is done, the corresponding json data is created in the json folder.
2.3 Converting json to masks
Convert by running labelme_json_to_dataset <path to json file>. I use the modified version of that script below.
The values of the NAME_LABEL_MAP dictionary in the script are the pixel values written into the mask for each label name. At training time we can parse them however we need; for example, I map pixel values 1 and 2 to class 1 and pixel values 3 and 4 to class 2, with background as 0.
import argparse
import json
import os
import os.path as osp
import warnings
import copy

import numpy as np
import PIL.Image
import yaml

from labelme import utils

# fixed label-name -> pixel-value mapping; every label name used in the
# annotations must have an entry here
NAME_LABEL_MAP = {
    '_background_': 0,
    "scratch1": 1,
    "scratch2": 2,
    "stain1": 3,
    "stain2": 4,
}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('json_file')
    parser.add_argument('-o', '--out', default=None)
    args = parser.parse_args()

    json_file = args.json_file

    file_list = os.listdir(json_file)
    for i in range(0, len(file_list)):
        path = os.path.join(json_file, file_list[i])
        filename = file_list[i][:-5]  # strip ".json"
        if os.path.isfile(path):
            data = json.load(open(path))
            img = utils.image.img_b64_to_arr(data['imageData'])
            lbl, lbl_names = utils.shape.labelme_shapes_to_label(img.shape, data['shapes'])

            # remap the per-image label values to the fixed values in NAME_LABEL_MAP
            lbl_tmp = copy.copy(lbl)
            for key_name in lbl_names:
                old_lbl_val = lbl_names[key_name]
                new_lbl_val = NAME_LABEL_MAP[key_name]
                lbl_tmp[lbl == old_lbl_val] = new_lbl_val
            lbl_names_tmp = {}
            for key_name in lbl_names:
                lbl_names_tmp[key_name] = NAME_LABEL_MAP[key_name]

            # assign the remapped label image and name dict back
            lbl = np.array(lbl_tmp, dtype=np.int8)
            lbl_names = lbl_names_tmp

            captions = ['%d: %s' % (l, name) for l, name in enumerate(lbl_names)]
            # lbl_viz = utils.draw.draw_label(lbl, img, captions)

            out_dir = osp.basename(file_list[i]).replace('.', '_')
            out_dir = osp.join(osp.dirname(file_list[i]), out_dir)
            if not osp.exists(out_dir):
                os.mkdir(out_dir)

            PIL.Image.fromarray(img).save(osp.join(out_dir, '{}.png'.format(filename)))
            PIL.Image.fromarray(lbl.astype(np.uint8)).save(osp.join(out_dir, '{}_gt.png'.format(filename)))
            # PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, '{}_viz.png'.format(filename)))

            with open(osp.join(out_dir, 'label_names.txt'), 'w') as f:
                for lbl_name in lbl_names:
                    f.write(lbl_name + '\n')

            warnings.warn('info.yaml is being replaced by label_names.txt')
            info = dict(label_names=lbl_names)
            with open(osp.join(out_dir, 'info.yaml'), 'w') as f:
                yaml.safe_dump(info, f, default_flow_style=False)

            print('Saved to: %s' % out_dir)


if __name__ == '__main__':
    main()
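Assuming the script above is saved as json_to_dataset.py (the filename is my choice, not from the original), it can be run over the whole json folder in one go:

python json_to_dataset.py path\to\json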
The converted folders look like this.
2.4 Merge the mask png from each folder into the cv2_mask folder.
import os

path = '~imagesM\\labelme_json'
files = os.listdir(path)
for file in files:
    # folders are named "<image>_json"; strip the "_json" suffix
    new = file[:-5]
    newname = os.path.join('~imageM\\cv2_mask', new)
    # pick the mask written by the conversion script ("<image>_gt.png")
    # instead of relying on os.listdir() ordering
    for f in os.listdir(os.path.join(path, file)):
        if f.endswith('_gt.png'):
            filename = os.path.join(path, file, f)
            print(filename)
            print(newname)
            os.rename(filename, newname + '.png')
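After the merge, a quick sanity check can confirm that every training image ended up with exactly one mask. This is just a sketch assuming the train/cv2_mask layout described in 2.1:

import os

imgs = sorted(os.listdir('train'))      # training images
masks = sorted(os.listdir('cv2_mask'))  # merged masks
assert len(imgs) == len(masks), 'image/mask count mismatch'
for img_name, mask_name in zip(imgs, masks):
    # each mask should share its base name with its image
    assert os.path.splitext(img_name)[0] == os.path.splitext(mask_name)[0]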
3. Adjusting the code
3.1 First get the official finetune example running, then make some adjustments to the data loading and the training code.
3.1.1 Modifying the PennFudanDataset class.
import os
import torch
import numpy as np
import torch.utils.data
from PIL import Image


class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance with 0 being background
        mask = Image.open(mask_path)

        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)

        # the official code assumes a single class:
        #   labels = torch.ones((num_objs,), dtype=torch.int64)
        # here, mask pixel values 1 and 2 are scratch (class 1),
        # and values 3 and 4 are stain (class 2)
        labels = torch.as_tensor(
            [1 if obj_id in (1, 2) else 2 for obj_id in obj_ids],
            dtype=torch.int64)

        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)
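A quick way to verify that __getitem__ returns what Mask R-CNN expects (the ./Images root matches what the predict script below uses; adjust it to your own data root):

dataset = PennFudanDataset('./Images')
img, target = dataset[0]
print(target['boxes'].shape)  # (num_objs, 4)
print(target['labels'])       # tensor of 1s (scratch) and 2s (stain)
print(target['masks'].shape)  # (num_objs, H, W)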
3.1.2 Modify train.py, changing num_classes to 3.
It is set to 3 (background + scratch + stain) because the fully connected heads of both the Faster R-CNN box branch and the Mask R-CNN mask branch must be rebuilt: maskrcnn_resnet50_fpn is pretrained on COCO, which has 91 output classes. The change amounts to the sketch below.
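Concretely, both heads are replaced with new ones sized for num_classes = 3; this is the same logic as the get_instance_segmentation_model helper in the predict script in 3.2.1:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

num_classes = 3  # background + scratch + stain
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# replace the COCO-trained (91-class) box classification head
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# replace the mask prediction head as well
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 256, num_classes)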
3.2 After training, the results can be visualized.
3.2.1 The predict code is as follows
from MyDataset import PennFudanDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from train_utils.draw_box_utils import draw_box
import torchvision
import torch
import os
import json
import matplotlib.pyplot as plt
from torchvision import transforms
import numpy as np
import cv2
import random


def get_instance_segmentation_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256

    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model


def random_colour_masks(image):
    """
    random_colour_masks
    parameters:
      - image - a predicted binary mask
    method:
      - each predicted mask is given a random colour for visualization
    """
    colours = [[0, 255, 0], [0, 0, 255], [255, 0, 0], [0, 255, 255], [255, 255, 0], [255, 0, 255], [80, 70, 180],
               [250, 80, 190], [245, 145, 50], [70, 150, 250], [50, 190, 190]]
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    r[image == 1], g[image == 1], b[image == 1] = colours[random.randrange(0, 10)]
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask


COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'scratch', 'stain'
]


def main():
    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = get_instance_segmentation_model(num_classes=3)

    # load trained weights
    train_weights = "./save_weights/resNetFpn-model-9.pth"
    assert os.path.exists(train_weights), "{} file does not exist.".format(train_weights)
    model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
    model.to(device)

    # read class_indict
    label_json_path = 'classes.json'
    assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
    json_file = open(label_json_path, 'r')
    class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}

    test_data = PennFudanDataset("./Images")

    for raw_img, _ in test_data:
        # from pil image to tensor, do not normalize image
        data_transform = transforms.Compose([transforms.ToTensor()])
        img = data_transform(raw_img)

        # put the model in evaluation mode
        model.eval()
        with torch.no_grad():
            # the model takes a list of images; take the first (only) result
            prediction = model([img.to(device)])[0]

        boxes = prediction["boxes"].to("cpu").numpy()
        labels = prediction["labels"].to("cpu").numpy()
        scores = prediction["scores"].to("cpu").numpy()

        # scores are sorted in descending order, so the first
        # len(pred_list) masks are exactly those above the threshold
        threshold = 0.3
        pred_list = [i for i, x in enumerate(scores) if x > threshold]
        # squeeze only the channel dimension so a single detection
        # still yields an (N, H, W) array
        masks = (prediction['masks'] > 0.5).squeeze(1).cpu().numpy()

        if len(pred_list) != 0:
            masks = masks[:len(pred_list)]
        else:
            masks = []

        draw_box(raw_img,
                 boxes,
                 labels,
                 scores,
                 category_index,
                 thresh=threshold,
                 line_thickness=2)

        # convert the PIL image (with boxes drawn) to a numpy RGB array
        cv_img = np.array(raw_img)
        for i in range(len(masks)):
            rgb_mask = random_colour_masks(masks[i])
            cv_img = cv2.addWeighted(cv_img, 1, rgb_mask, 0.5, 0)

        plt.figure(figsize=(20, 30))
        plt.imshow(cv_img)
        plt.xticks([])
        plt.yticks([])
        plt.show()

        # save the predicted image instead of showing it:
        # original_img.save("test_result.jpg")


if __name__ == '__main__':
    main()
draw_box_utils.py
import collections
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
import numpy as np

STANDARD_COLORS = [
    'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
    'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
    'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
    'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
    'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
    'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
    'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
    'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
    'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
    'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
    'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
    'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
    'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
    'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
    'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
    'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
    'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
    'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
    'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
    'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
    'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
    'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
    'WhiteSmoke', 'Yellow', 'YellowGreen'
]


def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
    for i in range(boxes.shape[0]):
        if scores[i] > thresh:
            box = tuple(boxes[i].tolist())  # numpy -> list -> tuple
            if classes[i] in category_index.keys():
                class_name = category_index[classes[i]]
            else:
                class_name = 'N/A'
            display_str = str(class_name)
            display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
            box_to_display_str_map[box].append(display_str)
            box_to_color_map[box] = STANDARD_COLORS[
                classes[i] % len(STANDARD_COLORS)]
        else:
            break  # scores are already sorted, so once one fails the threshold the rest will too


def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
    try:
        font = ImageFont.truetype('arial.ttf', 10)
    except IOError:
        font = ImageFont.load_default()

    # If the total height of the display strings added to the top of the bounding
    # box exceeds the top of the image, stack the strings below the bounding box
    # instead of above.
    display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
    # Each display_str has a top and bottom margin of 0.05x.
    total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)

    if top > total_display_str_height:
        text_bottom = top
    else:
        text_bottom = bottom + total_display_str_height
    # Reverse list and print from bottom to top.
    for display_str in box_to_display_str_map[box][::-1]:
        text_width, text_height = font.getsize(display_str)
        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                        (left + text_width, text_bottom)], fill=color)
        draw.text((left + margin, text_bottom - text_height - margin),
                  display_str,
                  fill='black',
                  font=font)
        text_bottom -= text_height - 2 * margin


def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8):
    box_to_display_str_map = collections.defaultdict(list)
    box_to_color_map = collections.defaultdict(str)

    filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)

    # Draw all boxes onto image.
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    for box, color in box_to_color_map.items():
        xmin, ymin, xmax, ymax = box
        (left, right, top, bottom) = (xmin * 1, xmax * 1,
                                      ymin * 1, ymax * 1)
        draw.line([(left, top), (left, bottom), (right, bottom),
                   (right, top), (left, top)], width=line_thickness, fill=color)
        draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
3.2.2 Displaying the results