MMdetection 代码阅读 RoI head: Shared2FCBBoxHead
要理解mmdetection的训练流程,首先得先理清楚mmcv中的runner和hook的作用。参考 https://zhuanlan.zhihu.com/p/369826931
训练/验证调用关系:
tools/train.py
-> api/train.py
-> runner.run()
-> EpochBasedRunner.train() / val() -> run_iter()
-> self.model.train_step() / self.model.val_step()
-> BaseDetector.train_step() -> forward()
-> 单阶段检测器 SingleStageDetector.forward_train()
-> 双阶段检测器 TwoStageDetector.forward_train()
双阶段检测器 TwoStageDetector.forward_train()
1. extract_feat(): backbone+FPN
2. self.rpn_head.forward_train()
3. self.roi_head.forward_train()
注意:build网络结构,都在构造函数中,通过register和build_from_cfg()的机制来完成模型实例化。
self.backbone = build_backbone(backbone) self.neck = build_neck(neck)
# 注意这里的构建的时候使用的train_cfg和test_cfg: rpn / rcnn
# rcnn和roi_head对应
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn) self.rpn_head = build_head(rpn_head_) roi_head.update(train_cfg=rcnn_train_cfg) roi_head.update(test_cfg=test_cfg.rcnn) self.roi_head = build_head(roi_head)
Shared2FCBBoxHead
继承关系
-> ConvFCBBoxHead 分类和回归共享2层全连接2FC
-> BBoxHead