MMDetection源码解析:Faster RCNN(5)--TwoStageDetector类
TwoStageDetector类定义在\mmdet\models\detectors\two_stage.py中:
import torch
import torch.nn as nn

# from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector


@DETECTORS.register_module()
class TwoStageDetector(BaseDetector):
    """Base class for two-stage detectors.

    Two-stage detectors typically consist of a region proposal network and a
    task-specific regression head.

    Args:
        backbone (dict): Config passed to ``build_backbone``.
        neck (dict, optional): Config passed to ``build_neck``. When None,
            no ``neck`` attribute is created.
        rpn_head (dict, optional): Config passed to ``build_head`` for the
            RPN. A copy is updated with the RPN train/test cfg, so the
            caller's dict is left untouched.
        roi_head (dict, optional): Config passed to ``build_head`` for the
            RoI head. NOTE: this dict is updated in place with the RCNN
            train/test cfg.
        train_cfg (optional): Training config; its ``rpn`` and ``rcnn``
            attributes are forwarded to the heads. Defaults to None.
        test_cfg (optional): Testing config; its ``rpn`` and ``rcnn``
            attributes are forwarded to the heads. Defaults to None.
        pretrained (str, optional): Path to pre-trained weights, forwarded
            to ``init_weights``. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(TwoStageDetector, self).__init__()
        self.backbone = build_backbone(backbone)

        if neck is not None:
            self.neck = build_neck(neck)

        if rpn_head is not None:
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            # Guard test_cfg exactly like train_cfg: it defaults to None and
            # dereferencing ``None.rpn`` would raise AttributeError.
            rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=rpn_test_cfg)
            self.rpn_head = build_head(rpn_head_)

        if roi_head is not None:
            # update train and test cfg here for now
            # TODO: refactor assigner & sampler
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            # Same None guard as for the RPN head above.
            rcnn_test_cfg = test_cfg.rcnn if test_cfg is not None else None
            roi_head.update(train_cfg=rcnn_train_cfg)
            roi_head.update(test_cfg=rcnn_test_cfg)
            self.roi_head = build_head(roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

    @property
    def with_rpn(self):
        """bool: whether the detector has RPN"""
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    @property
    def with_roi_head(self):
        """bool: whether the detector has a RoI head"""
        return hasattr(self, 'roi_head') and self.roi_head is not None

    def init_weights(self, pretrained=None):
        """Initialize the weights in detector.

        Initializes, in order, the base class, the backbone, the neck (if
        any), the RPN head (if any) and the RoI head (if any).

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        super(TwoStageDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        # ``with_neck`` is not defined in this class; presumably it is
        # provided by BaseDetector.
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                # A sequential neck is initialized module by module.
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        if self.with_rpn:
            self.rpn_head.init_weights()
        if self.with_roi_head:
            self.roi_head.init_weights(pretrained)

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/get_flops.py`

        Returns:
            tuple: RPN outputs (only when the detector has an RPN) followed
                by the RoI head's dummy outputs.
        """
        outs = ()
        # backbone
        x = self.extract_feat(img)
        # rpn
        if self.with_rpn:
            rpn_outs = self.rpn_head(x)
            outs = outs + (rpn_outs, )
        # Random stand-in proposals (1000 boxes, 4 values each) placed on the
        # same device as the input image.
        proposals = torch.randn(1000, 4).to(img.device)
        # roi_head
        roi_outs = self.roi_head.forward_dummy(x, proposals)
        outs = outs + (roi_outs, )
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """Run the training forward pass and collect the losses.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.

            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

            gt_labels (list[Tensor]): class indices corresponding to each box

            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

            proposals : override rpn proposals with custom proposals. Use when
                `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            # Fall back to the test-time RPN cfg when the train cfg has no
            # 'rpn_proposal' entry. Note the fallback is evaluated eagerly,
            # so ``self.test_cfg`` must not be None here.
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            # The RPN is trained without class labels (gt_labels=None).
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                gt_bboxes_ignore=gt_bboxes_ignore,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 gt_bboxes_ignore, gt_masks,
                                                 **kwargs)
        losses.update(roi_losses)

        return losses

    async def async_simple_test(self,
                                img,
                                img_meta,
                                proposals=None,
                                rescale=False):
        """Async test without augmentation."""
        # ``with_bbox`` is not defined in this class; presumably it is
        # provided by BaseDetector.
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(img)

        if proposals is None:
            proposal_list = await self.rpn_head.async_simple_test_rpn(
                x, img_meta)
        else:
            proposal_list = proposals

        return await self.roi_head.async_simple_test(
            x, proposal_list, img_meta, rescale=rescale)

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without augmentation."""
        assert self.with_bbox, 'Bbox head must be implemented.'

        x = self.extract_feat(img)

        # Use the RPN to generate proposals unless the caller supplied them.
        if proposals is None:
            proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        else:
            proposal_list = proposals

        return self.roi_head.simple_test(
            x, proposal_list, img_metas, rescale=rescale)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].
        """
        # ``extract_feats`` (plural) is not defined in this class; presumably
        # it is provided by BaseDetector.
        x = self.extract_feats(imgs)
        proposal_list = self.rpn_head.aug_test_rpn(x, img_metas)
        return self.roi_head.aug_test(
            x, proposal_list, img_metas, rescale=rescale)
TwoStageDetector继承自BaseDetector类,主要有以下函数:
(1) __init__():初始化函数,主要是对backbone,neck,rpn_head,roi_head等进行设置;
(2) init_weights():初始化参数值,包括对backbone,neck,rpn_head,roi_head的参数进行初始化;
(3) extract_feat():把输入的图像数据送入backbone,如果有neck则再送入neck,并且得到输出的特征图;
(4) forward_train():主要有img,gt_bboxes,gt_labels几个参数,训练时的前向输出,
rpn_losses, proposal_list = self.rpn_head.forward_train( x, img_metas, gt_bboxes, gt_labels=None, gt_bboxes_ignore=gt_bboxes_ignore, proposal_cfg=proposal_cfg)
通过调用rpn_head的forward_train()函数计算RPN的损失函数值,并且得到proposal的列表proposal_list,
roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
通过调用roi_head的forward_train()函数计算ROI的损失函数值.