GluonCV object detection: training on your own dataset
https://gluon-cv.mxnet.io/build/examples_datasets/detection_custom.html
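The official tutorial offers two options: a .lst file, or XML files in VOC format.
Annotation tools can export VOC-format labels directly, but if your annotations are stored as JSON or in some other format, you first have to convert them to VOC format.
So I went with the first option, the .lst file, whose format is simple.
Just convert your own JSON (or other) annotations into it: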
import json
import os

import numpy as np


def write_line(img_path, im_shape, boxes, ids, idx):
    h, w, c = im_shape
    # header: A = header length (4: A, B, C, D), B = values per object (5: class id + 4 coords),
    # C = image width, D = image height
    A = 4
    B = 5
    C = w
    D = h
    # concat class ids and bboxes -> one row [cls, xmin, ymin, xmax, ymax] per object
    labels = np.hstack((ids.reshape(-1, 1), boxes)).astype('float')
    # normalized bboxes (recommended)
    labels[:, (1, 3)] /= float(w)
    labels[:, (2, 4)] /= float(h)
    # flatten
    labels = labels.flatten().tolist()
    str_idx = [str(idx)]
    str_header = [str(x) for x in [A, B, C, D]]
    str_labels = [str(x) for x in labels]
    str_path = [img_path]
    line = '\t'.join(str_idx + str_header + str_labels + str_path) + '\n'
    return line


# collect every annotation file under train_front/<subdir>/*.json
json_url = []
for sub in os.listdir('train_front'):
    for name in os.listdir(os.path.join('train_front', sub)):
        if name.endswith('.json'):
            json_url.append('train_front/' + sub + '/' + name)
print(len(json_url))

# 27853 = number of annotated images in this particular dataset; replace with your own count.
# The first half goes to train.lst, the second half to val.lst.
SPLIT = 27853 // 2
first_flag = []  # boxes of the first written image, printed at the end as a quick check
cnt = cnt1 = cnt2 = 0
with open('train.lst', 'w') as fwtrain, open('val.lst', 'w') as fwval:
    for json_url_index in json_url:
        with open(json_url_index, 'r') as f:
            for json_line in f:  # one JSON record (one image) per line
                js = json.loads(json_line)
                if 'person' not in js:
                    continue
                boxes = []
                ids = []
                for person in js['person']:
                    attrs = person['attrs']
                    # skip ignored, heavily occluded and invisible pedestrians
                    if attrs['ignore'] == 'yes' or attrs['occlusion'] in ('heavily_occluded', 'invisible'):
                        continue
                    boxes.append(person['data'])  # must be [xmin, ymin, xmax, ymax] in pixels
                    ids.append(0)                 # single class: pedestrian = 0
                if not boxes:
                    continue
                # absolute image path: <image root>/<json path without ".json">/<image_key>
                url = '/mnt/hdfs-data-4/data/jian.yin/' + json_url_index[:-5] + '/' + js['image_key']
                shape = (js['height'], js['width'], 3)
                if not first_flag:
                    first_flag = boxes
                ids = np.array(ids)
                if cnt < SPLIT:
                    fwtrain.write(write_line(url, shape, boxes, ids, cnt1))
                    cnt1 += 1
                else:
                    fwval.write(write_line(url, shape, boxes, ids, cnt2))
                    cnt2 += 1
                cnt += 1
print(first_flag)
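The script above decides train vs. val by comparing a running counter with a hardcoded image count (27853), which silently produces a lopsided split if the dataset changes. A minimal alternative sketch, assuming you first collect every valid record (for example a (url, shape, boxes, ids) tuple) into a list instead of writing it immediately; split_records is a hypothetical helper, not part of GluonCV:

```python
import random


def split_records(records, val_ratio=0.5, shuffle=False, seed=0):
    """Split a list of records into (train, val) lists.

    Hypothetical helper. With shuffle=False the split is sequential,
    like the original script: the first part is train, the rest is val.
    """
    records = list(records)
    if shuffle:
        # fixed seed -> the same split every time the script runs
        random.Random(seed).shuffle(records)
    n_val = int(round(len(records) * val_ratio))
    n_train = len(records) - n_val
    return records[:n_train], records[n_train:]
```

Each returned list can then be written with write_line, using the position in that list as idx. Keeping shuffle=False mostly keeps near-duplicate consecutive frames out of both sets if the images come from video; shuffle=True mixes scenes at the cost of that property.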
The .lst files are now ready.
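Before training, it is worth reading the files back. A small checker sketch, assuming the layout produced by write_line above (index, header A B C D, then B values per object, then the image path, all tab-separated) and normalized coordinates; parse_lst_line is a hypothetical helper:

```python
def parse_lst_line(line):
    """Parse one detection .lst line into (idx, boxes, path) and validate it.

    Assumed layout: idx \t A \t B \t C \t D \t [cls x1 y1 x2 y2]... \t path,
    where A is the header length and B the number of values per object.
    """
    fields = line.rstrip('\n').split('\t')
    idx = int(fields[0])
    path = fields[-1]
    values = [float(x) for x in fields[1:-1]]
    header_len, obj_len = int(values[0]), int(values[1])
    objs = values[header_len:]
    if obj_len < 5 or len(objs) == 0 or len(objs) % obj_len != 0:
        raise ValueError('malformed label section: %r' % line)
    boxes = [objs[i:i + obj_len] for i in range(0, len(objs), obj_len)]
    for _cls, x1, y1, x2, y2 in (b[:5] for b in boxes):
        # normalized corners must lie in [0, 1] with xmin < xmax and ymin < ymax
        if not (0.0 <= x1 < x2 <= 1.0 and 0.0 <= y1 < y2 <= 1.0):
            raise ValueError('box out of range or inverted: %r' % ([x1, y1, x2, y2],))
    return idx, boxes, path
```

Run it over every line of train.lst and val.lst. A ValueError usually means the source annotation was not in [xmin, ymin, xmax, ymax] order or extends past the image border; whether to clip such boxes or drop them is your call.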
Next, register your own dataset in the training script:
https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73
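You cannot just drop the earlier data-loading code in here as is.
Add the dataset the way the tutorial describes. For validation, the quick-and-dirty option is to reuse what the VOC branch already uses, VOC07MApMetric.
Or skip validation entirely by commenting out that part of the script: https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L393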
elif dataset.lower() == 'pedestrian':
    # LstDetection lives in gluoncv.data; make sure the script imports it
    # (e.g. "from gluoncv.data import LstDetection" at the top)
    train_dataset = LstDetection('train.lst', root=os.path.expanduser('.'))
    val_dataset = LstDetection('val.lst', root=os.path.expanduser('.'))
    # quick sanity check: dataset size, first image shape and its label
    print(len(train_dataset))
    first_img, first_label = train_dataset[0]
    print(first_img.shape)
    print(first_label)
    classes = ('pedestrian',)
    val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=classes)
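The label printed by the sanity check is not the raw .lst numbers. As far as I can tell from gluoncv's LstDetection (with its default coord_normalized=True), each label is reshaped to one row per object, the class id is moved from the first column to the fifth, and the normalized coordinates are scaled back to pixels. The numpy sketch below is written independently for illustration (lst_label_to_gcv is a hypothetical name, and details may differ between GluonCV versions); it shows what that print should roughly look like:

```python
import numpy as np


def lst_label_to_gcv(raw, img_h, img_w):
    """Convert the numeric fields of one .lst line to [xmin, ymin, xmax, ymax, cls, ...] rows.

    raw: header (A, B, C, D) followed by B values per object, coordinates normalized.
    Returns an (N, B) float32 array in absolute pixel coordinates.
    """
    raw = np.asarray(raw, dtype='float32').ravel()
    header_len, obj_len = int(raw[0]), int(raw[1])
    if (len(raw) - header_len) % obj_len != 0:
        raise ValueError('label length does not match the per-object width')
    objs = raw[header_len:].reshape(-1, obj_len)
    # fancy indexing returns a copy: box columns first, class id in column 4, extras after
    order = [1, 2, 3, 4, 0] + list(range(5, obj_len))
    out = objs[:, order]
    out[:, (0, 2)] *= img_w  # de-normalize x coordinates
    out[:, (1, 3)] *= img_h  # de-normalize y coordinates
    return out
```

For a 640x480 image with one box written as 0.1 0.2 0.3 0.4 (class 0), this yields a single row of roughly 64, 96, 192, 192, 0.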
Training parameters:
https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73
Add your own training parameters, or simply reuse the VOC ones:
if args.dataset == 'voc' or args.dataset == 'pedestrian':
    args.epochs = int(args.epochs) if args.epochs else 20
    args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20'
    args.lr = float(args.lr) if args.lr else 0.001
    args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
    args.wd = float(args.wd) if args.wd else 5e-4
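lr_decay_epoch='14,20' lists the epochs at which the learning rate is cut. A sketch of the resulting step schedule, assuming the script multiplies the rate by a decay factor once each milestone is reached (0.1 here, which I believe is the default of the script's --lr-decay option; check your version). lr_at_epoch is a hypothetical helper:

```python
def lr_at_epoch(base_lr, lr_decay_epoch, epoch, lr_decay=0.1):
    """Step schedule: multiply base_lr by lr_decay once per milestone <= epoch.

    lr_decay_epoch is a comma-separated string such as '14,20'.
    lr_decay=0.1 is an assumed default, not read from the script.
    """
    steps = sorted(float(s) for s in lr_decay_epoch.split(',') if s.strip())
    n_decays = sum(1 for s in steps if epoch >= s)
    return base_lr * (lr_decay ** n_decays)
```

With epochs=20 and epochs numbered from 0, training ends after epoch 19, so if your script loops over range(start_epoch, epochs) the second decay at 20 never takes effect; raise the epoch count if you want both decays to happen.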
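Finally, add a model-name mapping for your dataset in model_zoo.py. If you installed gluoncv with pip install gluoncv, you have to edit the installed copy under site-packages instead.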
https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_zoo.py#L32
'faster_rcnn_resnet50_v1b_pedestrian': faster_rcnn_resnet50_v1b_voc,
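Why this mapping is needed: in the script versions I have seen, the model name is built by joining 'faster_rcnn', the --network value and the --dataset value, and get_model looks that name up in the _models dict of model_zoo.py, so a new dataset name has no entry until you add one. The toy, self-contained mimic below illustrates only that lookup; the constructor is a hypothetical stand-in, not GluonCV's real function:

```python
def faster_rcnn_resnet50_v1b_voc(**kwargs):
    # hypothetical stand-in for the real constructor in gluoncv's model zoo
    return ('faster_rcnn_resnet50_v1b', 'voc', kwargs)


_models = {
    'faster_rcnn_resnet50_v1b_voc': faster_rcnn_resnet50_v1b_voc,
    # the added entry: the pedestrian name reuses the VOC constructor
    'faster_rcnn_resnet50_v1b_pedestrian': faster_rcnn_resnet50_v1b_voc,
}


def get_model(name, **kwargs):
    """Toy lookup: unknown names fail, known names call their constructor."""
    name = name.lower()
    if name not in _models:
        raise ValueError('Model %s is not supported' % name)
    return _models[name](**kwargs)


# name as the training script is assumed to build it (network='resnet50_v1b', dataset='pedestrian')
net_name = '_'.join(('faster_rcnn', 'resnet50_v1b', 'pedestrian'))
```

Since the new key points at the VOC constructor, the network is presumably built with VOC's class list; it is worth checking whether your script version resets the classes to those of your dataset.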