mmdetection模型构建及Registry注册器机制
好久没有做目标检测了,最近突然又接到了检测任务,跟同事讨论时,发现自己竟然忘了很多细节,
于是想趁训练模型的间隙,重新梳理下目标检测。我选择了mmdetection来学习,除了目标检测本身,
这个框架中很多python的使用技巧和框架的设计模式也是值得学习。最近一年基本都在使用python,
希望能将这些技巧应用在以后的工作之中。mmdetection封装的很好,很方便使用,比如我想训练的
话只需如下的一条指令。在train.py中,通过build_detector来构建模型(参数来自 faster_rcnn_r50_fpn_1x_voc0712.py),
python tools/train.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
build_detector的定义如下,最后通过build_from_cfg来构建模型,这里看到了让人困惑的Registry.
from mmdet.cv_core.utils import Registry, build_from_cfg
from torch import nn
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
def build(cfg, registry, default_args=None):
"""Build a module.
Args:
cfg (dict, list[dict]): The config of modules, is is either a dict
or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
一、Registry是干什么的
Registry完成了从字符串到类的映射,这样模型信息、训练时的参数信息,只需要写入到一个配置文件里,然后使用注册器来实例化即可。
二、如何实现
通过装饰器来实现。在mmcv/mmcv/registry.py中,我们看到了Registry类。其中完成字符串到类的映射,实际上就是下面的成员函数来实现的,核心代码就一句,将要注册的类添加到字典里,key为类的名字(字符串)。下面通过一个小例子,
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
self._module_dict[module_name] = module_class
来看看它的构建过程。在导入下面这个文件时,首先创建FRUIT实例,接着通过装饰器(这里是用成员函数装饰类)来注册Apple类,调用register_module,然后调用_register(注意:参数cls即为类Apple),最后调用_register_module完成Apple的添加。完成后,FRUIT就有了个字典成员:['Apple']=APPle。在build_from_cfg中,传入模型参数,即可通过FRUIT构建Apple的实例化对象。
class Registry(): def __init__(self, name): self._name = name self._module_dict = dict() def _register_module(self, module_class, module_name, force): self._module_dict[module_name] = module_class def register_module(self, name=None, force=False, module=None): print('register module ...') def _register(cls): print('cls ', cls) self._register_module( module_class=cls, module_name=name, force=force) return cls return _register FRUIT = Registry('fruit') @FRUIT.register_module() class Apple(): def __init__(self, name): self.name = name
三、Registry在mmdetection中是如何构建模型的
我们来看一下构建模型的流程:
1、在train.py中通过build_detector构建模型,其中cfg.model, cfg.train_cfg如下,包括模型信息和训练信息。
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
2、最关键的部分来了。首先通过build_detector构建模型, 其中传入的DETECTORS是Registry的实例,在该实例中,包含了所有已经实现的检测器,如图。那么它是在哪里实现添加这些检测的类的呢?
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
看了前面那个小例子我们就能猜到,一定是在这些检测类上,用Registry对其进行了注册,看看faster rcnn的实现,证明了我们的猜想。这样只要
在定义这些类时,对其进行注册,那么就会自动加入到DETECTORS这个实例的成员字典里,非常的巧妙。当我们想实例化某个检测网络时,传入其字符名称
即可。
既然都看到这里了,就进一步看看网络时如何继续构建的吧。mmdetection将网络分成了几个部分,backbone,head,neck等。在TwoStageDetector(
faster rcnn的基类)中,可以看到分别构建了这几个部分。head, neck, loss等,同样是通过Registry来注册实现的。最后就是将这几个部分组合起来即可。
@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
"""Base class for two-stage detectors.
Two-stage detectors typically consisting of a region proposal network and a
task-specific regression head.
"""
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(TwoStageDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = rpn_head.copy()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=test_cfg.rcnn)
self.roi_head = build_head(roi_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
四、Registry的应用
在我最近的一个数据处理的项目中,有三类数据,sample, measure 和image。如果我想得到某个数据类型的实例,我是通过if来
判断的。那如果数据类别很多呢?就像检测器这样有几十种,再用if就显得很蠢了。借用Registry机制,可以轻松解决这个问题。