Detectron2 Training + Testing Code Framework
1. run_predict.py
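This script runs single-image inference with a COCO-pretrained Mask R-CNN (R50-FPN, 3x): it builds a config from a local YAML file, loads local weights through DefaultPredictor, and saves the visualized instance predictions to output.jpg.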
import torch, torchvision
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

# Read the test image (OpenCV loads images in BGR channel order)
im = cv2.imread('./input.jpg')
# cv2_imshow(im)

# Build the config from a local config file and local pretrained weights
cfg = get_cfg()
# cfg.merge_from_file(model_zoo.get_config_file('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'))
cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
cfg.MODEL.WEIGHTS = 'model_final_f10217.pkl'

predictor = DefaultPredictor(cfg)
outputs = predictor(im)
# print(outputs['instances'].pred_classes)

# Visualize the predictions; im[:, :, ::-1] converts BGR to RGB for the Visualizer
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs['instances'].to('cpu'))
plt.figure(figsize=(14, 10))
plt.imshow(v.get_image())  # get_image() already returns RGB
plt.savefig('./output.jpg')
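outputs follows detectron2's standard output format: outputs['instances'] is an Instances object whose per-detection fields can be read directly. A minimal sketch of inspecting them (the field names are detectron2's documented API; the printing itself is only illustrative):

instances = outputs['instances'].to('cpu')
print(len(instances))               # number of detections
print(instances.pred_classes)       # class indices, shape (N,)
print(instances.scores)             # confidence scores, shape (N,)
print(instances.pred_boxes.tensor)  # boxes in XYXY_ABS format, shape (N, 4)
print(instances.pred_masks.shape)   # binary masks, shape (N, H, W)

# Map class indices to readable names through the dataset metadata
class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
for cls, score in zip(instances.pred_classes, instances.scores):
    print(class_names[int(cls)], float(score))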
2. run_train.py
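This script fine-tunes the COCO-pretrained model on the balloon dataset from the official detectron2 tutorial: it registers the custom dataset by converting VIA annotations into detectron2's standard dict format, visualizes a few training samples as a sanity check, and then trains with DefaultTrainer.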
import torch, torchvision
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode

# if your dataset is in COCO format, this cell can be replaced by the following three lines:
# from detectron2.data.datasets import register_coco_instances
# register_coco_instances("my_dataset_train", {}, "json_annotation_train.json", "path/to/image/dir")
# register_coco_instances("my_dataset_val", {}, "json_annotation_val.json", "path/to/image/dir")

################### Read data ####################
def get_balloon_dicts(img_dir):
    # Convert VIA (VGG Image Annotator) polygon annotations into
    # detectron2's standard dataset-dict format
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}

        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]

        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width

        annos = v["regions"]
        objs = []
        for _, anno in annos.items():
            assert not anno["region_attributes"]
            anno = anno["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]  # flatten [(x, y), ...] into [x, y, x, y, ...]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts

for d in ["train", "val"]:
    DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("dataset/balloon/" + d))
    MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
balloon_metadata = MetadataCatalog.get("balloon_train")

# Sanity check: visualize three random training samples
# (each iteration overwrites read.jpg, so only the last sample is kept)
dataset_dicts = get_balloon_dicts("dataset/balloon/train")
for d in random.sample(dataset_dicts, 3):
    img = cv2.imread(d["file_name"])
    visualizer = Visualizer(img[:, :, ::-1], metadata=balloon_metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(d)
    # cv2_imshow(out.get_image()[:, :, ::-1])
    plt.figure(figsize=(14, 10))
    plt.imshow(out.get_image())  # get_image() already returns RGB
    plt.savefig('./read.jpg')

################### Train ####################
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
cfg.DATASETS.TRAIN = ("balloon_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = 'model_final_f10217.pkl'  # fine-tune from the COCO-pretrained checkpoint
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 300
cfg.SOLVER.STEPS = []  # do not decay the learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only one class (balloon)

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
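After training, the natural follow-up (as in the official balloon tutorial) is to run the fine-tuned model on the validation split and score it with COCO metrics. A sketch that continues from the cfg above; note that the COCOEvaluator constructor signature differs slightly across detectron2 versions:

from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader

# Load the weights just produced by the trainer and raise the test-time threshold
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
predictor = DefaultPredictor(cfg)

# COCO-style AP on the registered balloon_val split
evaluator = COCOEvaluator("balloon_val", output_dir=cfg.OUTPUT_DIR)
val_loader = build_detection_test_loader(cfg, "balloon_val")
print(inference_on_dataset(predictor.model, val_loader, evaluator))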
3. print_intermediate_features.py
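This script shows how to execute only part of the network: it builds the raw model with build_model, loads weights with DetectionCheckpointer, and runs just the backbone to dump the intermediate FPN feature maps.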
import torch, torchvision
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.modeling import build_model
from detectron2.modeling import build_backbone
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.structures import ImageList

########### Read data ############
# im = cv2.imread('./input.jpg')
# cv2_imshow(im)

########## Specify the config file ##########
cfg = get_cfg()
# cfg.merge_from_file(model_zoo.get_config_file('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'))
cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml')
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
cfg.MODEL.WEIGHTS = 'model_final_a54504.pkl'  # COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml

########## Preprocess the input image: convert it to a CHW float tensor ##########
image = cv2.imread('./input.jpg')
height, width = image.shape[:2]
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = [{"image": image, "height": height, "width": width}]

########## Execute only part of the network ##########
model = build_model(cfg)  # returns a torch.nn.Module with random parameters
DetectionCheckpointer(model).load('model_final_a54504.pkl')  # load trained weights
model.eval()
with torch.no_grad():
    images = model.preprocess_image(inputs)   # normalize, pad, batch into an ImageList
    features = model.backbone(images.tensor)  # run only the FPN backbone
    # outputs = model(image)
    # features = model.backbone(image)

# features is a dict keyed by FPN level:
print(features.keys())
with open('./print_features.txt_segmentation', 'w+') as f:
    print("features is a dict with keys ['p2', 'p3', 'p4', 'p5', 'p6']", features['p2'], file=f)

# print(type(model.named_children()))
# print(model.named_children())
'''
for name, child in model.named_children():
    for i in child:
        print(i)
    # print(type(child))
    # print(type(name))
'''
'''
types of child and name:
<class 'detectron2.modeling.backbone.fpn.FPN'> <class 'str'>
<class 'detectron2.modeling.proposal_generator.rpn.RPN'> <class 'str'>
<class 'detectron2.modeling.roi_heads.roi_heads.StandardROIHeads'> <class 'str'>
'''
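The same pattern extends past the backbone: GeneralizedRCNN's submodules can be called one at a time, which is handy for inspecting RPN proposals or pooled ROI features. A sketch following the "partially execute a model" section of the detectron2 docs, reusing images and features from the script above:

with torch.no_grad():
    # Region Proposal Network: returns proposal boxes (losses are empty in eval mode)
    proposals, _ = model.proposal_generator(images, features)
    # Box/mask heads: turn proposals into final Instances predictions
    instances, _ = model.roi_heads(images, features, proposals)
    # Pooled per-instance mask features, as in the official docs example
    mask_features = [features[f] for f in model.roi_heads.in_features]
    mask_features = model.roi_heads.mask_pooler(mask_features, [x.pred_boxes for x in instances])
    print(len(instances[0]), mask_features.shape)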