mmdetection模型构建及Registry注册器机制

        好久没有做目标检测了,最近突然又接到了检测任务,跟同事讨论时,发现自己竟然忘了很多细节,

于是想趁训练模型的间隙,重新梳理下目标检测。我选择了mmdetection来学习,除了目标检测本身,

这个框架中很多python的使用技巧和框架的设计模式也是值得学习。最近一年基本都在使用python,

希望能将这些技巧应用在以后的工作之中。mmdetection封装的很好方便使用,比如我想训练的

话只需如下的一条指令。在train.py中,通过build_detector来构建模型(参数来自 faster_rcnn_r50_fpn_1x_voc0712.py),

python tools/train.py  configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py

build_detector的定义如下,最后通过build_from_cfg来构建模型,这里看到了让人困惑的Registry.

from mmdet.cv_core.utils import Registry, build_from_cfg
from torch import nn

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, is is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

       

一、Registry是干什么的

        Registry完成了从字符串到类的映射,这样模型信息、训练时的参数信息,只需要写入到一个配置文件里,然后使用注册器来实例化即可。

二、如何实现

        通过装饰器来实现。在mmcv/mmcv/registry.py中,我们看到了Registry类。其中完成字符串到类的映射,实际上就是下面的成员函数来实现的,核心代码就一句,将要注册的类添加到字典里,key为类的名字(字符串)。下面通过一个小例子,

 def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        self._module_dict[module_name] = module_class

 来看看它的构建过程。在导入下面这个文件时,首先创建FRUIT实例,接着通过装饰器(这里是用成员函数装饰类)来注册Apple类,调用register_module,然后调用_register(注意:参数cls即为类Apple),最后调用_register_module完成Apple的添加。完成后,FRUIT就有了个字典成员:['Apple']=APPle。在build_from_cfg中,传入模型参数,即可通过FRUIT构建Apple的实例化对象。

class Registry():

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def _register_module(self, module_class, module_name, force):
        self._module_dict[module_name] = module_class


    def register_module(self, name=None, force=False, module=None):
        print('register module ...')
        def _register(cls):
            print('cls ', cls)
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

FRUIT = Registry('fruit')

@FRUIT.register_module()
class Apple():
    def __init__(self, name):
        self.name = name

def build_from_cfg(cfg, registry, default_args=None):
   

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
    
    return obj_cls(**args)

三、Registry在mmdetection中是如何构建模型的

          我们来看一下构建模型的流程:

        1、在train.py中通过build_detector构建模型,其中cfg.model, cfg.train_cfg如下,包括模型信息和训练信息。

model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) 

         

        2、最关键的部分来了。首先通过build_detector构建模型, 其中传入的DETECTORS是Registry的实例,在该实例中,包含了所有已经实现的检测器,如图。那么它是在哪里实现添加这些检测的类的呢?

def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

 

           看了前面那个小例子我们就能猜到,一定是在这些检测类上,用Registry对其进行了注册,看看faster rcnn的实现,证明了我们的猜想。这样只要

在定义这些类时,对其进行注册,那么就会自动加入到DETECTORS这个实例的成员字典里,非常的巧妙。当我们想实例化某个检测网络时,传入其字符名称

即可。

       既然都看到这里了,就进一步看看网络时如何继续构建的吧。mmdetection将网络分成了几个部分,backbone,head,neck等。在TwoStageDetector(

faster rcnn的基类)中,可以看到分别构建了这几个部分。head, neck, loss等,同样是通过Registry来注册实现的。最后就是将这几个部分组合起来即可。

@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
    """Base class for two-stage detectors.

    Two-stage detectors typically consisting of a region proposal network and a
    task-specific regression head.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()
        self.backbone = build_backbone(backbone)

        if neck is not None:
            self.neck = build_neck(neck)

        if rpn_head is not None:
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
            self.rpn_head = build_head(rpn_head_)

        if roi_head is not None:
            # update train and test cfg here for now
            # TODO: refactor assigner & sampler
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            roi_head.update(train_cfg=rcnn_train_cfg)
            roi_head.update(test_cfg=test_cfg.rcnn)
            self.roi_head = build_head(roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

 

四、Registry的应用

          在我最近的一个数据处理的项目中,有三类数据,sample, measure 和image。如果我想得到某个数据类型的实例,我是通过if来

判断的。那如果数据类别很多呢?就像检测器这样有几十种,再用if就显得很蠢了。借用Registry机制,可以轻松解决这个问题。

 

 

 

 

        

 

              

      

       

      

posted @ 2021-01-25 20:16  牧马人夏峥  阅读(2723)  评论(0编辑  收藏  举报