MMDetection源码解析:Faster RCNN(8)--BBoxHead类
BBoxHead类继承自nn.Module类,定义在\mmdet\models\roi_heads\bbox_heads\bbox_head.py中,其作用是输出ROI Pooling的分类和回归值.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import auto_fp16, force_fp32
from torch.nn.modules.utils import _pair

from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy


@HEADS.register_module()
class BBoxHead(nn.Module):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively.

    The head flattens (or average-pools and then flattens) its input and
    feeds it to ``fc_cls`` (``num_classes + 1`` outputs, the extra one being
    the background class) and ``fc_reg`` (4 outputs when
    ``reg_class_agnostic`` is true, otherwise ``4 * num_classes``).
    """

    def __init__(self,
                 with_avg_pool=False,
                 with_cls=True,
                 with_reg=True,
                 roi_feat_size=7,
                 in_channels=256,
                 num_classes=80,
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     target_means=[0., 0., 0., 0.],
                     target_stds=[0.1, 0.1, 0.2, 0.2]),
                 reg_class_agnostic=False,
                 reg_decoded_bbox=False,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
        """Build the coder, the losses and the two fully connected layers.

        Args:
            with_avg_pool (bool): If true, average-pool the input with a
                kernel of ``roi_feat_size`` before the fc layers.
            with_cls (bool): Whether to create the classification branch.
            with_reg (bool): Whether to create the regression branch.
            roi_feat_size (int | tuple): Spatial size of the RoI feature;
                expanded to a pair with ``_pair``.
            in_channels (int): Number of channels of the input feature.
            num_classes (int): Number of foreground classes.
            bbox_coder (dict): Config passed to ``build_bbox_coder``.
            reg_class_agnostic (bool): If true, regress a single box per
                RoI instead of one box per class.
            reg_decoded_bbox (bool): If true, the regression loss is
                computed on decoded boxes instead of encoded deltas.
            loss_cls (dict): Config passed to ``build_loss``.
            loss_bbox (dict): Config passed to ``build_loss``.

        Raises:
            AssertionError: If both ``with_cls`` and ``with_reg`` are false.
        """
        super(BBoxHead, self).__init__()
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.reg_class_agnostic = reg_class_agnostic
        self.reg_decoded_bbox = reg_decoded_bbox
        # NOTE(review): presumably consulted by the auto_fp16 / force_fp32
        # decorators - confirm against mmcv.
        self.fp16_enabled = False

        # NOTE: the dict defaults above are shared between calls; this
        # method only hands them to the builders and never mutates them.
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        else:
            # Without pooling the whole feature map is flattened, so the
            # fc input width is channels * height * width.
            in_channels *= self.roi_feat_area
        if self.with_cls:
            # need to add background class
            self.fc_cls = nn.Linear(in_channels, num_classes + 1)
        if self.with_reg:
            out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
            self.fc_reg = nn.Linear(in_channels, out_dim_reg)
        self.debug_imgs = None

    def init_weights(self):
        """Initialize the fc layers with normal weights and zero biases."""
        # conv layers are already initialized by ConvModule
        if self.with_cls:
            nn.init.normal_(self.fc_cls.weight, 0, 0.01)
            nn.init.constant_(self.fc_cls.bias, 0)
        if self.with_reg:
            nn.init.normal_(self.fc_reg.weight, 0, 0.001)
            nn.init.constant_(self.fc_reg.bias, 0)

    @auto_fp16()
    def forward(self, x):
        """Compute classification scores and box predictions.

        Args:
            x (Tensor): Input feature; its first dimension is kept and all
                remaining dimensions are flattened.

        Returns:
            tuple: ``(cls_score, bbox_pred)``. An entry is ``None`` when the
            corresponding branch is disabled.
        """
        if self.with_avg_pool:
            x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        cls_score = self.fc_cls(x) if self.with_cls else None
        bbox_pred = self.fc_reg(x) if self.with_reg else None
        return cls_score, bbox_pred

    def _get_target_single(self, pos_bboxes, neg_bboxes, pos_gt_bboxes,
                           pos_gt_labels, cfg):
        """Build the training targets for one sampling result.

        Positive samples occupy the first ``num_pos`` rows of every output
        and negative samples the last ``num_neg`` rows.

        Args:
            pos_bboxes (Tensor): Positive boxes; the first dimension is the
                number of positives.
            neg_bboxes (Tensor): Negative boxes; the first dimension is the
                number of negatives.
            pos_gt_bboxes (Tensor): Ground-truth boxes matched to the
                positives.
            pos_gt_labels (Tensor): Ground-truth labels matched to the
                positives.
            cfg: Object with a ``pos_weight`` attribute; values ``<= 0``
                select a weight of 1.0 for positive labels.

        Returns:
            tuple: ``(labels, label_weights, bbox_targets, bbox_weights)``
            with shapes ``(num_samples,)``, ``(num_samples,)``,
            ``(num_samples, 4)`` and ``(num_samples, 4)``.
        """
        num_pos = pos_bboxes.size(0)
        num_neg = neg_bboxes.size(0)
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_bboxes.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        label_weights = pos_bboxes.new_zeros(num_samples)
        bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
        bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[:num_pos] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_bboxes, pos_gt_bboxes)
            else:
                # The loss is computed on decoded boxes, so the targets
                # are the ground-truth boxes themselves.
                pos_bbox_targets = pos_gt_bboxes
            bbox_targets[:num_pos, :] = pos_bbox_targets
            bbox_weights[:num_pos, :] = 1
        if num_neg > 0:
            label_weights[-num_neg:] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights

    def get_targets(self,
                    sampling_results,
                    gt_bboxes,
                    gt_labels,
                    rcnn_train_cfg,
                    concat=True):
        """Build training targets for every sampling result.

        Args:
            sampling_results (list): Objects exposing ``pos_bboxes``,
                ``neg_bboxes``, ``pos_gt_bboxes`` and ``pos_gt_labels``.
            gt_bboxes: Unused by this method.
            gt_labels: Unused by this method.
            rcnn_train_cfg: Forwarded to ``_get_target_single`` as ``cfg``.
            concat (bool): If true, concatenate the per-result targets
                along dimension 0.

        Returns:
            tuple: ``(labels, label_weights, bbox_targets, bbox_weights)``;
            tensors when ``concat`` is true, otherwise whatever
            ``multi_apply`` returns (presumably one sequence per output).
        """
        pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
        neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_target_single,
            pos_bboxes_list,
            neg_bboxes_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def loss(self,
             cls_score,
             bbox_pred,
             rois,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        """Compute the classification and regression losses.

        Args:
            cls_score (Tensor | None): Classification scores.
            bbox_pred (Tensor | None): Box predictions.
            rois (Tensor): RoIs; columns ``1:`` are used as boxes when
                ``reg_decoded_bbox`` is true.
            labels (Tensor): Target labels; values in
                ``[0, num_classes - 1]`` are foreground.
            label_weights (Tensor): Per-sample classification weights.
            bbox_targets (Tensor): Regression targets.
            bbox_weights (Tensor): Regression weights.
            reduction_override: Forwarded to both loss modules.

        Returns:
            dict: May contain ``loss_cls`` and ``acc`` (when ``cls_score``
            is not ``None`` and non-empty) and ``loss_bbox`` (when
            ``bbox_pred`` is not ``None``).
        """
        losses = dict()
        if cls_score is not None:
            # Average over the samples that carry a positive weight, with
            # a floor of 1 to avoid dividing by zero.
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            if cls_score.numel() > 0:
                losses['loss_cls'] = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            bg_class_ind = self.num_classes
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            pos_inds = (labels >= 0) & (labels < bg_class_ind)
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                # Converted once and reused for every boolean indexing below.
                pos_mask = pos_inds.type(torch.bool)
                if self.reg_decoded_bbox:
                    bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), 4)[pos_mask]
                else:
                    # Pick, for every positive sample, the 4 values that
                    # belong to its ground-truth class.
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), -1, 4)[pos_mask, labels[pos_mask]]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_mask],
                    bbox_weights[pos_mask],
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)
            else:
                # No positive sample: sum over an empty selection so that
                # 'loss_bbox' is still present and derived from bbox_pred.
                losses['loss_bbox'] = bbox_pred[pos_inds].sum()
        return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_bboxes(self,
                   rois,
                   cls_score,
                   bbox_pred,
                   img_shape,
                   scale_factor,
                   rescale=False,
                   cfg=None):
        """Turn the head outputs into boxes and scores.

        Args:
            rois (Tensor): RoIs; columns ``1:`` are the box coordinates.
            cls_score (Tensor | list[Tensor] | None): Classification
                scores; a list is averaged element-wise first.
            bbox_pred (Tensor | None): Box predictions. When ``None`` the
                RoIs themselves are used (clipped to ``img_shape``).
            img_shape (tuple | None): Image shape, indexed as
                ``img_shape[0]`` for the y limit and ``img_shape[1]`` for
                the x limit.
            scale_factor (float | sequence): Boxes are divided by this
                value when ``rescale`` is true.
            rescale (bool): Whether to divide the boxes by
                ``scale_factor``.
            cfg: If ``None`` the boxes and scores are returned without
                NMS; otherwise must provide ``score_thr``, ``nms`` and
                ``max_per_img``.

        Returns:
            tuple: ``(bboxes, scores)`` when ``cfg`` is ``None``, otherwise
            the ``(det_bboxes, det_labels)`` pair from ``multiclass_nms``.
        """
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = self.bbox_coder.decode(
                rois[:, 1:], bbox_pred, max_shape=img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            if img_shape is not None:
                # Indexing with a list is advanced indexing and returns a
                # copy, so an in-place clamp_ on it would be discarded.
                # Assign the clamped values back instead.
                bboxes[:, [0, 2]] = bboxes[:, [0, 2]].clamp(
                    min=0, max=img_shape[1])
                bboxes[:, [1, 3]] = bboxes[:, [1, 3]].clamp(
                    min=0, max=img_shape[0])

        if rescale and bboxes.size(0) > 0:
            if isinstance(scale_factor, float):
                bboxes /= scale_factor
            else:
                scale_factor = bboxes.new_tensor(scale_factor)
                # View as groups of 4 so a per-coordinate factor
                # broadcasts over every box of every class.
                bboxes = (bboxes.view(bboxes.size(0), -1, 4) /
                          scale_factor).view(bboxes.size()[0], -1)

        if cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(bboxes, scores,
                                                    cfg.score_thr, cfg.nms,
                                                    cfg.max_per_img)

            return det_bboxes, det_labels

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
        """Refine bboxes during training.

        Args:
            rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
                and bs is the sampled RoIs per image. The first column is
                the image id and the next 4 columns are x1, y1, x2, y2.
            labels (Tensor): Shape (n*bs, ).
            bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                is a gt bbox.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[Tensor]: Refined bboxes of each image in a mini-batch.

        Example:
            >>> # xdoctest: +REQUIRES(module:kwarray)
            >>> import kwarray
            >>> import numpy as np
            >>> from mmdet.core.bbox.demodata import random_boxes
            >>> self = BBoxHead(reg_class_agnostic=True)
            >>> n_roi = 2
            >>> n_img = 4
            >>> scale = 512
            >>> rng = np.random.RandomState(0)
            >>> img_metas = [{'img_shape': (scale, scale)}
            ...              for _ in range(n_img)]
            >>> # Create rois in the expected format
            >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
            >>> img_ids = torch.randint(0, n_img, (n_roi,))
            >>> img_ids = img_ids.float()
            >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
            >>> # Create other args
            >>> labels = torch.randint(0, 2, (n_roi,)).long()
            >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
            >>> # For each image, pretend random positive boxes are gts
            >>> is_label_pos = (labels.numpy() > 0).astype(int)
            >>> lbl_per_img = kwarray.group_items(is_label_pos,
            ...                                   img_ids.numpy())
            >>> pos_per_img = [sum(lbl_per_img.get(gid, []))
            ...                for gid in range(n_img)]
            >>> pos_is_gts = [
            >>>     torch.randint(0, 2, (npos,)).byte().sort(
            >>>         descending=True)[0]
            >>>     for npos in pos_per_img
            >>> ]
            >>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
            >>>                    pos_is_gts, img_metas)
            >>> print(bboxes_list)
        """
        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() <= len(img_metas)

        bboxes_list = []
        for i, img_meta_ in enumerate(img_metas):
            # Rows of `rois` whose image id (column 0) equals i.
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            bbox_pred_ = bbox_preds[inds]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)

            # filter gt bboxes
            # Only the leading len(pos_is_gts_) entries can be dropped;
            # all remaining RoIs are always kept.
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep

            bboxes_list.append(bboxes[keep_inds.type(torch.bool)])

        return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (Tensor): shape (n, 4) or (n, 5)
            label (Tensor): shape (n, )
            bbox_pred (Tensor): shape (n, 4*(#class)) or (n, 4)
            img_meta (dict): Image meta info.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)

        if not self.reg_class_agnostic:
            # Gather the 4 regression values stored at columns
            # [4 * label, 4 * label + 3] for each RoI.
            label = label * 4
            inds = torch.stack((label, label + 1, label + 2, label + 3), 1)
            bbox_pred = torch.gather(bbox_pred, 1, inds)
        assert bbox_pred.size(1) == 4

        if rois.size(1) == 4:
            new_rois = self.bbox_coder.decode(
                rois, bbox_pred, max_shape=img_meta['img_shape'])
        else:
            bboxes = self.bbox_coder.decode(
                rois[:, 1:], bbox_pred, max_shape=img_meta['img_shape'])
            # Re-attach the first column (image id) in front of the boxes.
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)

        return new_rois
主要的函数有:
(1) __init__():初始化函数,主要参数包括POI Pooling的尺寸大小,输入通道数等等;
(2) get_targets():计算目标值,通过调用_get_target_single()实现;
(3) _get_target_single():计算FPN中每一个层次的目标值;
(4) get_bboxes():输出分类和回归值;
(5) forward():前向传播;
(6) loss():计算损失函数值.