Detectron2训练visdrone记录

准备

VOC标签转换参见这篇
注意:object_name = name_dict[box[4]] 改为 object_name = name_dict[box[5]]。为了与detectron2统一,
标签文件夹命名为Annotations,图片文件夹命名为JPEGImages,train.txt位于xxx/ImageSets/Main/。

train

构建instance

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
import os
import xml.etree.ElementTree as ET
from fvcore.common.file_io import PathManager

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode

# Public API of this module: a star-import only exposes the registration helper.
__all__ = ["register_visdrone_voc"]

# Category names; the position in this list is used as the category id.
# Kept as a list (not a tuple) because it is handed to the dataset metadata as-is.
CLASS_NAMES = [
    '__background__',   # 0 - placeholder so that real classes start at 1
    'pedestrian',       # 1
    'people',           # 2
    'bicycle',          # 3
    'car',              # 4
    'van',              # 5
    'truck',            # 6
    'tricycle',         # 7
    'awning-tricycle',  # 8
    'bus',              # 9
    'motor',            # 10
]


def load_voc_instances(dirname: str, split: str) -> list:
    """
    Load Pascal-VOC-style detection annotations into Detectron2's dataset format.

    Args:
        dirname: dataset root containing "Annotations", "ImageSets", "JPEGImages"
        split (str): name of the split file under "ImageSets/Main" (without
            ".txt"), e.g. one of "train", "test", "val", "trainval"

    Returns:
        A list with one dict per image id listed in the split file. Each dict has
        the keys "file_name", "image_id", "height", "width" and "annotations";
        every annotation holds "category_id" (the index of the object name in
        CLASS_NAMES), "bbox" (four floats) and "bbox_mode" (BoxMode.XYXY_ABS).

    Raises:
        ValueError: if an object's <name> is not listed in CLASS_NAMES.
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        # `np.str` was only an alias of the builtin and has been removed from
        # recent NumPy releases, so use `str` directly.
        # `atleast_1d` keeps a split file with a single id iterable: loadtxt
        # returns a 0-d array in that case, which cannot be looped over.
        fileids = np.atleast_1d(np.loadtxt(f, dtype=str))

    # Needs to read many small annotation files. Makes sense at local
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            # continue
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H]
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W)
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append(
                {"category_id": CLASS_NAMES.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    return dicts


def register_visdrone_voc(name, dirname, split, year):
    """
    Register one VisDrone split (stored in VOC layout) with Detectron2.

    Args:
        name: key under which the dataset is registered in both catalogs
        dirname: dataset root directory, forwarded to `load_voc_instances`
        split: split name, forwarded to `load_voc_instances`
        year: value stored as-is in the dataset metadata
    """

    def _load_split():
        # Deferred loader: annotations are only parsed when the catalog calls it.
        return load_voc_instances(dirname, split)

    DatasetCatalog.register(name, _load_split)

    metadata = MetadataCatalog.get(name)
    metadata.set(
        thing_classes=CLASS_NAMES,
        dirname=dirname,
        year=year,
        split=split,
    )

训练采用Faster R-CNN with FPN,backbone使用ResNeXt-101,分组卷积为32x8d,即32个group,每个group 8个filter。
注意:如果因图片损坏无法训练,可修改PIL库的ImageFile.py,将LOAD_TRUNCATED_IMAGES改为True(也可以不改库文件,在训练脚本开头执行 from PIL import ImageFile; ImageFile.LOAD_TRUNCATED_IMAGES = True)。

from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultPredictor
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
from visdrone_voc import *
import os
import cv2
import torch

# Register the train / val / test-dev splits.
# Each row is (registered dataset name, dataset root directory, split name).
_VISDRONE_SPLITS = (
    ('VISDRONE_VOC', '/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-train', 'train'),
    ('VISDRONE_VAL', '/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val', 'val'),
    ('VISDRONE_TEST', '/home/chenzhengxi/data/VisDrone/VisDrone2019-DET-test-dev', 'test'),
)
for dataset_name, dataset_root, split_name in _VISDRONE_SPLITS:
    register_visdrone_voc(dataset_name, dataset_root, split_name, 2012)

# Start from Detectron2's defaults and overlay the model/experiment YAML.
# NOTE(review): cfg.DATASETS.TRAIN / TEST are not set in this script, so they
# presumably come from the YAML file - confirm before running.
cfg = get_cfg()
cfg.merge_from_file('configs/faster_rcnn_X_101_32x8d_FPN_3x.yaml')

# cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 4
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# Pass resume=True to continue training from the latest saved weights.
trainer.resume_or_load(resume=False)
trainer.train()

# The commented lines below load one specific checkpoint instead.
# cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_0229999.pth")
# checkpointer = DetectionCheckpointer(trainer.model)
# checkpointer.load(cfg.MODEL.WEIGHTS)

# Testing score threshold for this model.
# NOTE(review): this is assigned after the trainer was constructed; confirm that
# the already-built trainer.model actually picks the new value up.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

# Evaluate on the first dataset listed in cfg.DATASETS.TEST with the VOC evaluator.
evaluator = PascalVOCDetectionEvaluator(cfg.DATASETS.TEST[0])
# val_loader = build_detection_test_loader(cfg, "VISDRONE_VAL")
# result_val = inference_on_dataset(trainer.model, val_loader, evaluator)
# print(result_val)
eval_results = trainer.test(cfg, trainer.model, evaluator)
print(eval_results)

# Optional single-image visualisation of the predicted boxes.
# predictor = DefaultPredictor(cfg)
# im = cv2.imread('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val/JPEGImages/0000026_03500_d_0000031.jpg')
# outputs = predictor(im)
# ooo = outputs['instances'].to(torch.device("cpu"))
# boxes = ooo.pred_boxes.tensor.numpy()
# print(boxes)
# for i in range(len(boxes)):
#     cv2.rectangle(im, tuple(boxes[i, 0:2]), tuple(boxes[i, 2:4]), (0, 255, 0), 2)
#
# cv2.imshow('visdrone', im)
# cv2.waitKey(0)

测试记录

注意:数据集中没有__background__类别的标注,计算AP时会除0,需修改pascal_voc_evaluation.py第90行附近,跳过该类别:

            for cls_id, cls_name in enumerate(self._class_names):
+++             if cls_id == 0:  # __background__
+++                 continue
                lines = predictions.get(cls_id, [""])
AP AP50 AP75 iter datasets
20.8397 39.0423 19.9213 94999 val
17.0480 32.8580 16.1755 94999 test
22.2992 40.0259 21.7078 169999 val
18.0168 33.5292 17.6131 169999 test
22.9258 41.0643 22.4904 214999 val
18.1249 33.7228 17.6461 214999 test
22.8556 40.9861 22.3667 269999 val
18.0256 33.5866 17.5259 269999 test
可以看出效果远高于yolo,最终配置和权重下载,提取码: 74s4

posted on 2020-06-08 16:03  haskell  阅读(1091)  评论(0)  编辑  收藏  举报