调用训练好的 Detectron 模型对视频进行推理
from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals from collections import defaultdict import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2) from caffe2.python import workspace from detectron.core.config import assert_and_infer_cfg # from detectron.core.config import cfg from detectron.core.config import merge_cfg_from_file from detectron.utils.io import cache_url from detectron.utils.timer import Timer import detectron.core.test_engine as infer_engine import detectron.datasets.dummy_datasets as dummy_datasets import detectron.utils.c2 as c2_utils import detectron.utils.vis as vis_utils import numpy as np import pycocotools.mask as mask_util c2_utils.import_detectron_ops() # OpenCL may be enabled by default in OpenCV3; disable it because it's not # thread safe and causes unwanted GPU memory allocations. # cv2.ocl.setUseOpenCL(False) #coco # weights = "/home/gaomh/Desktop/test/cocomodel/model_final.pkl" # config = "/home/gaomh/Desktop/test/cocomodel/e2e_mask_rcnn_R-101-FPN_1x.yaml" #hat weights = "/home/gaomh/Desktop/test/models/kp-person/model_final.pkl" config = "/home/gaomh/Desktop/test/models/kp-person/e2e_keypoint_rcnn_X-101-32x8d-FPN_1x.yaml" #foot # weights = "/home/gaomh/Desktop/test/trainMOdel/train/voc_2007_train/retinanet/model_final.pkl" # config = "/home/gaomh/Desktop/test/trainMOdel/train/voc_2007_train/retinanet_R-50-FPN_1x.0.yaml" gpuid = 0 workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) merge_cfg_from_file(config) assert_and_infer_cfg(cache_urls=False) model = infer_engine.initialize_model_from_cfg(weights, gpuid) dataset = dummy_datasets.get_foot_dataset() def convert_from_cls_format(cls_boxes, cls_segms): """Convert from the class boxes/segms/keyps format generated by the testing code. 
""" box_list = [b for b in cls_boxes if len(b) > 0] if len(box_list) > 0: boxes = np.concatenate(box_list) else: boxes = None if cls_segms is not None: segms = [s for slist in cls_segms for s in slist] else: segms = None classes = [] for j in range(len(cls_boxes)): classes += [j] * len(cls_boxes[j]) return boxes, segms, classes def vis_one_image(boxes, cls_segms, thresh=0.9): """Visual debugging of detections.""" result_box = [] result_mask = [] if isinstance(boxes, list): boxes, segms, classes = convert_from_cls_format(boxes,cls_segms) if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh: return result_box,result_mask if segms is not None: masks=mask_util.decode(segms) # Display in largest to smallest order to reduce occlusion areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) sorted_inds = np.argsort(-areas) for i in sorted_inds: bbox = boxes[i, :4] score = boxes[i, -1] if score < thresh: continue result_box.append([dataset.classes[classes[i]], score, int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])]) if segms is not None and len(segms)>i: result_mask.append(masks[:,:,i]) else: result_mask.append([]) return result_box, result_mask # cap = cv2.VideoCapture("rtsp://192.168.123.231") cap = cv2.VideoCapture("/home/gaomh/per.mp4") # cv2.namedWindow("img", cv2.WINDOW_NORMAL) while cap.isOpened(): res, frame = cap.read() timers = defaultdict(Timer) with c2_utils.NamedCudaScope(0): cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all( model, frame, None, timers=timers ) # img = vis_utils.vis_one_image_opencv(im=frame, boxes=cls_boxes, segms=cls_segms, keypoints=cls_keyps, thresh=0.7, kp_thresh=2, show_box=False # ,dataset=dataset, show_class=True) # vis_utils result_box, result_mask = vis_one_image(cls_boxes, cls_segms) print(result_box) for box in result_box: tit = box[0] thr = box[1] left = box[2] top = box[3] right = box[4] bottom = box[5] # if tit is "person": cv2.rectangle(frame, (left, top), (right, bottom), (255, 0, 0), 
1) cv2.putText(frame, tit, (left-10, top-10), cv2.FONT_HERSHEY_COMPLEX, 0.4, (0, 0, 255)) # print(result_box) cv2.imshow("img", frame) key = cv2.waitKey(1) if key == ord("q"): break cv2.destroyAllWindows()
修改 dummy_datasets.py，增加与所训练模型相对应的类别定义：
def get_foot_dataset():
    """Return a minimal dataset stub for the custom 'foot' detector.

    Only the ``classes`` attribute is populated: a dict mapping contiguous
    class ids to class names, with id 0 reserved for '__background__'.
    """
    names = ('__background__', 'person', 'foot', 'car')
    stub = AttrDict()
    stub.classes = dict(enumerate(names))
    return stub
效果图