faster-rcnn代码阅读-数据预处理

毫无疑问,faster-rcnn是目标检测领域的一个里程碑式的算法。本文主要是本人阅读python版本的faster-rcnn代码的一个记录,算法的具体原理本文也会有介绍,但是为了对该算法有一个整体性的理解以及更好地理解本文,还需事先阅读faster-rcnn的论文并参考网上的一些说明性的博客(如一文读懂Faster RCNN)。官方的py-faster-rcnn代码库已经不再维护了,我使用的是经过少许修改后的代码(主要是numpy版本不兼容导致的一些错误),可以参考这里

faster-rcnn有2种训练方式,一是两阶段法,二是端到端的方法,本文主要讲述端到端的方法,并以训练代码的运行顺序进行阅读。

一、数据准备

程序首先从faster-rcnn/tools/train_net.py运行,程序如下:

 84 if __name__ == '__main__':
 85     args = parse_args()
 86 
 87     print('Called with args:')
 88     print(args)
 89 
 90     if args.cfg_file is not None:
 91         cfg_from_file(args.cfg_file)
 92     if args.set_cfgs is not None:
 93         cfg_from_list(args.set_cfgs)
 94 
 95     cfg.GPU_ID = args.gpu_id
 96 
 97     print('Using config:')
 98     pprint.pprint(cfg)
 99 
100     if not args.randomize:
101         # fix the random seeds (numpy and caffe) for reproducibility
102         np.random.seed(cfg.RNG_SEED)
103         caffe.set_random_seed(cfg.RNG_SEED)
104 
105     # set up caffe
106     caffe.set_mode_gpu()
107     caffe.set_device(args.gpu_id)
108 
109     imdb, roidb = combined_roidb(args.imdb_name)
110     print '{:d} roidb entries'.format(len(roidb))
111 
112     output_dir = get_output_dir(imdb)
113     print 'Output will be saved to `{:s}`'.format(output_dir)
114 
115     train_net(args.solver, roidb, output_dir,
116               pretrained_model=args.pretrained_model,
117               max_iters=args.max_iters)

该部分cfg_from_file(args.cfg_file)调用faster-rcnn/lib/fast_rcnn/config.py中的cfg_from_file方法,从faster-rcnn/experiments/cfgs/faster_rcnn_end2end.yml文件中加载一些端到端训练时用到的参数配置,这句话会修改config.py中一些参数的值,下面是faster_rcnn_end2end.yml中的内容:

 1 EXP_DIR: faster_rcnn_end2end
 2 TRAIN:
 3   HAS_RPN: True
 4   IMS_PER_BATCH: 1
 5   BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
 6   RPN_POSITIVE_OVERLAP: 0.7
 7   RPN_BATCHSIZE: 256
 8   PROPOSAL_METHOD: gt
 9   BG_THRESH_LO: 0.0
10 TEST:
11   HAS_RPN: True

train_net第109行imdb, roidb = combined_roidb(args.imdb_name)是数据准备的核心部分。返回的imdb是类pascal_voc的一个实例,后面只用到了其的一些路径,作用不大。roidb则包含了训练网络所需要的所有信息。下面看一下它的产生过程:

 1 def combined_roidb(imdb_names):
 2     def get_roidb(imdb_name):
 3         imdb = get_imdb(imdb_name)
 4         print 'Loaded dataset `{:s}` for training'.format(imdb.name)
 5         imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
 6         print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
 7         roidb = get_training_roidb(imdb)
 8         return roidb
 9 
10     roidbs = [get_roidb(s) for s in imdb_names.split('+')]
11     roidb = roidbs[0]
12     if len(roidbs) > 1:
13         for r in roidbs[1:]:
14             roidb.extend(r)
15         imdb = datasets.imdb.imdb(imdb_names)
16     else:
17         imdb = get_imdb(imdb_names)
18     return imdb, roidb

下面逐一分析combined_roidb函数中的每步操作。

1.1、get_imdb

首先由imdb = get_imdb(imdb_name)调用faster-rcnn/lib/datasets/factory.py中的get_imdb方法,返回了一个faster-rcnn/lib/datasets/pascal_voc.py中的pascal_voc类的实例。我输入函数get_imdb的参数是'voc_2007_trainval',与其对应的初始化pascal_voc类的参数为image_set='trainval',year='2007'。在这个pascal_voc实例中,数据集的路径由以下方式获取:

 1     def __init__(self, image_set, year, devkit_path=None):
 2         imdb.__init__(self, 'voc_' + year + '_' + image_set)
 3         self._year = year
 4         self._image_set = image_set
 5         self._devkit_path = self._get_default_path() if devkit_path is None \
 6                             else devkit_path
 7         self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
 8 
 9     def _get_default_path(self):
10         """
11         Return the default path where PASCAL VOC is expected to be installed.
12         """
13         return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)

至于cfg.DATA_DIR,由faster-rcnn/lib/fast_rcnn/config.py文件的如下内容确定:

1 # Root directory of project
2 __C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
3 
4 # Data directory
5 __C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data'))

因此,由给出的以上参数确定的数据集的路径为self._data_path=$CODE_DIR/faster-rcnn/data/VOCdevkit2007/VOC2007。

1.2、imdb.set_proposal_method

其次,imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)会调用faster-rcnn/lib/datasets/imdb.py中类imdb中的set_proposal_method方法(因为pascal_voc继承自imdb),进而使self.roidb_handler为类pascal_voc中的gt_roidb方法(因为参数method='gt')。

这步操作非常重要,因为函数gt_roidb就是读取pascal_voc数据集,并返回所有图片信息的函数,代码如下:

 1     def gt_roidb(self):
 2         """
 3         Return the database of ground-truth regions of interest.
 4 
 5         This function loads/saves from/to a cache file to speed up future calls.
 6         """
 7         cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
 8         if os.path.exists(cache_file):
 9             with open(cache_file, 'rb') as fid:
10                 roidb = cPickle.load(fid)
11             print '{} gt roidb loaded from {}'.format(self.name, cache_file)
12             return roidb
13 
14         gt_roidb = [self._load_pascal_annotation(index)
15                     for index in self.image_index]
16         with open(cache_file, 'wb') as fid:
17             cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
18         print 'wrote gt roidb to {}'.format(cache_file)
19 
20         return gt_roidb

在函数gt_roidb中,首先判断有没有cache_file(它会在第一次读取数据集标注文件之后将所有字典形式的标注信息写进一个文件中,我创建数据类用的imdb_name='voc_2007_trainval',因此对应的文件名为faster-rcnn/data/voc_2007_trainval_gt_roidb.pkl),若存在,则直接从中读取标注信息,若不存在,则通过调用_load_pascal_annotation将pascal_voc数据集中每张图片的标注信息读取读取到一个字典中,具体代码如下:

 1     def _load_pascal_annotation(self, index):
 2         """
 3         Load image and bounding boxes info from XML file in the PASCAL VOC
 4         format.
 5         """
 6         filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
 7         tree = ET.parse(filename)
 8         objs = tree.findall('object')
 9         if not self.config['use_diff']:
10             # Exclude the samples labeled as difficult
11             non_diff_objs = [
12                 obj for obj in objs if int(obj.find('difficult').text) == 0]
13             # if len(non_diff_objs) != len(objs):
14             #     print 'Removed {} difficult objects'.format(
15             #         len(objs) - len(non_diff_objs))
16             objs = non_diff_objs
17         num_objs = len(objs)
18 
19         boxes = np.zeros((num_objs, 4), dtype=np.uint16)
20         gt_classes = np.zeros((num_objs), dtype=np.int32)
21         overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
22         # "Seg" area for pascal is just the box area
23         seg_areas = np.zeros((num_objs), dtype=np.float32)
24 
25         # Load object bounding boxes into a data frame.
26         for ix, obj in enumerate(objs):
27             bbox = obj.find('bndbox')
28             # Make pixel indexes 0-based
29             x1 = float(bbox.find('xmin').text) - 1
30             y1 = float(bbox.find('ymin').text) - 1
31             x2 = float(bbox.find('xmax').text) - 1
32             y2 = float(bbox.find('ymax').text) - 1
33             cls = self._class_to_ind[obj.find('name').text.lower().strip()]
34             boxes[ix, :] = [x1, y1, x2, y2]
35             gt_classes[ix] = cls
36             overlaps[ix, cls] = 1.0
37             seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
38 
39         overlaps = scipy.sparse.csr_matrix(overlaps)
40 
41         return {'boxes' : boxes,
42                 'gt_classes': gt_classes,
43                 'gt_overlaps' : overlaps,
44                 'flipped' : False,
45                 'seg_areas' : seg_areas}

值得一提的是,字典中,overlaps指的是该张图片中,每个物体与其它ground true之间的重叠比例,不过从代码来看,默认一张图片中所有的物体(ground true)之间是没有重叠的,因而overlaps的shape为(num_objs, self.num_classes),它的每一行(第一个轴上)只有一个元素是1.0,其它的元素都是0。这种默认方式虽然与实际标注情况不符,但对后面的操作并没有影响。

1.3、get_training_roidb

roidb = get_training_roidb(imdb)会调用faster-rcnn/lib/fast_rcnn/train.py中的get_training_roidb函数

 1 def get_training_roidb(imdb):
 2     """Returns a roidb (Region of Interest database) for use in training."""
 3     if cfg.TRAIN.USE_FLIPPED:
 4         print 'Appending horizontally-flipped training examples...'
 5         imdb.append_flipped_images()
 6         print 'done'
 7 
 8     print 'Preparing training data...'
 9     rdl_roidb.prepare_roidb(imdb)
10     print 'done'
11 
12     return imdb.roidb

会进行2步操作。

1.3.1、imdb.append_flipped_images()

 1     def append_flipped_images(self):
 2         num_images = self.num_images
 3         widths = self._get_widths()
 4         for i in xrange(num_images):
 5             boxes = self.roidb[i]['boxes'].copy()
 6             oldx1 = boxes[:, 0].copy()
 7             oldx2 = boxes[:, 2].copy()
 8             boxes[:, 0] = widths[i] - oldx2 - 1
 9             boxes[:, 2] = widths[i] - oldx1 - 1
10             assert (boxes[:, 2] >= boxes[:, 0]).all()
11             entry = {'boxes' : boxes,
12                      'gt_overlaps' : self.roidb[i]['gt_overlaps'],
13                      'gt_classes' : self.roidb[i]['gt_classes'],
14                      'flipped' : True}
15             self.roidb.append(entry)
16         self._image_index = self._image_index * 2

此句调用faster-rcnn/lib/datasets/imdb.py中类imdb的append_flipped_images方法,其作用是将数据集中的每张图的所有bounding box标签进行水平翻转,然后将图片信息字典中的'flipped'置为True,并将这一新的字典添加进原始的roidb list中,这样图片信息列表的长度就变为了原来的2倍。最后将数据集实例中的_image_index成员(所有图片名的list)复制了一份,长度也变为了原来的2倍。值得关注的是self.roidb是类imdb的一个属性(由Python内置的@property装饰器修饰)。属性和方法的不同之处在于调用方法需要加(),如某方法名为methodname,调用方式为methodname(),而调用属性不需要加(),self.roidb的构造过程如以下代码所示。另外,装饰器@methodname.setter可以把一个方法变成可以赋值的属性,“=”右侧的表达式作为传入方法的实参,如以下代码中的@roidb_handler.setter。

 1     @property
 2     def roidb_handler(self):
 3         return self._roidb_handler
 4 
 5     @roidb_handler.setter
 6     def roidb_handler(self, val):
 7         self._roidb_handler = val
 8 
 9     def set_proposal_method(self, method):
10         method = eval('self.' + method + '_roidb')
11         self.roidb_handler = method
12 
13     @property
14     def roidb(self):
15         # A roidb is a list of dictionaries, each with the following keys:
16         #   boxes
17         #   gt_overlaps
18         #   gt_classes
19         #   flipped
20         if self._roidb is not None:
21             return self._roidb
22         self._roidb = self.roidb_handler()
23         return self._roidb

1.3.2、rdl_roidb.prepare_roidb(imdb)

 1 def prepare_roidb(imdb):
 2     """Enrich the imdb's roidb by adding some derived quantities that
 3     are useful for training. This function precomputes the maximum
 4     overlap, taken over ground-truth boxes, between each ROI and
 5     each ground-truth box. The class with maximum overlap is also
 6     recorded.
 7     """
 8     sizes = [PIL.Image.open(imdb.image_path_at(i)).size
 9              for i in xrange(imdb.num_images)]
10     roidb = imdb.roidb
11     for i in xrange(len(imdb.image_index)):
12         roidb[i]['image'] = imdb.image_path_at(i)
13         roidb[i]['width'] = sizes[i][0]
14         roidb[i]['height'] = sizes[i][1]
15         # need gt_overlaps as a dense array for argmax
16         gt_overlaps = roidb[i]['gt_overlaps'].toarray()
17         # max overlap with gt over classes (columns)
18         max_overlaps = gt_overlaps.max(axis=1)
19         # gt class that had the max overlap
20         max_classes = gt_overlaps.argmax(axis=1)
21         roidb[i]['max_classes'] = max_classes
22         roidb[i]['max_overlaps'] = max_overlaps
23         # sanity checks
24         # max overlap of 0 => class should be zero (background)
25         zero_inds = np.where(max_overlaps == 0)[0]
26         assert all(max_classes[zero_inds] == 0)
27         # max overlap > 0 => class should not be zero (must be a fg class)
28         nonzero_inds = np.where(max_overlaps > 0)[0]
29         assert all(max_classes[nonzero_inds] != 0)

此句调用faster-rcnn/lib/roi_data_layer/roidb.py中的prepare_roidb函数,其作用是在图片信息字典中加入5个键值。分别是'image'(图片的全路径),'width'(图片的宽度),'height'(图片的高度),'max_classes','max_overlaps'。

至此roidb的构造过程便结束了,下面总结一下:最终得到的roidb是一个包含数据集中所有图片(以及它的水平翻转)信息的list,每张图的信息(保存在一个字典中)对应着list中的一个元素。每张图片的信息结构如下:

 1 {
 2     'boxes' : boxes,                # picture's bounding box: xmin, ymin, xmax, ymax(pixel indexes 0-based),
 3                                     # shape: (num_objs, 4), dtype=np.uint16
 4     'gt_classes': gt_classes,       # gt class label(background is 0), shape: (num_objs,), dtype=np.int32
 5     'gt_overlaps' : overlaps,       # each obj's max overlap with one of gt, shape: (num_objs, self.num_classes), dtype=np.float32
 6     'flipped' : False,
 7     'seg_areas' : seg_areas,        # area for each obj in one picture, shape: (num_objs,), dtype=np.float32
 8     'image' : image_full_path,
 9     'width' : image_width,
10     'height' : image_height,
11     'max_classes' : max_classes,    # equal to gt_classes, shape: (num_objs,), dtype=np.int64
12     'max_overlaps' : max_overlaps,  # all elements are 1.0, shape: (num_objs,), dtype=np.float32
13 }

1.4、get_output_dir

train_net.py第112行output_dir = get_output_dir(imdb)调用faster-rcnn/lib/fast_rcnn/config.py中的get_output_dir函数

 1 def get_output_dir(imdb, net=None):
 2     """Return the directory where experimental artifacts are placed.
 3     If the directory does not exist, it is created.
 4 
 5     A canonical path is built using the name from an imdb and a network
 6     (if not None).
 7     """
 8     outdir = osp.abspath(osp.join(__C.ROOT_DIR, 'output', __C.EXP_DIR, imdb.name))
 9     if net is not None:
10         outdir = osp.join(outdir, net.name)
11     if not os.path.exists(outdir):
12         os.makedirs(outdir)
13     return outdir

函数中的__C.EXP_DIR在faster_rcnn_end2end.yml中的配置为faster_rcnn_end2end,因此最终outdir=$CODE_DIR/faster-rcnn/output/faster_rcnn_end2end/voc_2007_trainval

1.5、train_net

使用以上得到的roidb,output_dir等作为参数,训练网络。调用faster-rcnn/lib/fast_rcnn/train.py中的train_net函数

 1 def train_net(solver_prototxt, roidb, output_dir,
 2               pretrained_model=None, max_iters=40000):
 3     """Train a Fast R-CNN network."""
 4 
 5     roidb = filter_roidb(roidb)
 6     sw = SolverWrapper(solver_prototxt, roidb, output_dir,
 7                        pretrained_model=pretrained_model)
 8 
 9     print 'Solving...'
10     model_paths = sw.train_model(max_iters)
11     print 'done solving'
12     return model_paths

1.5.1、filter_roidb

roidb = filter_roidb(roidb)调用filter_roidb函数对上述得到的roidb再按照一定的要求作进一步的过滤:

 1 __C.TRAIN.FG_THRESH = 0.5
 2 __C.TRAIN.BG_THRESH_HI = 0.5
 3 __C.TRAIN.BG_THRESH_LO = 0.1
 4 
 5 def filter_roidb(roidb):
 6     """Remove roidb entries that have no usable RoIs."""
 7 
 8     def is_valid(entry):
 9         # Valid images have:
10         #   (1) At least one foreground RoI OR
11         #   (2) At least one background RoI
12         overlaps = entry['max_overlaps']
13         # find boxes with sufficient overlap
14         fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
15         # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
16         bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
17                            (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
18         # image is only valid if such boxes exist
19         valid = len(fg_inds) > 0 or len(bg_inds) > 0
20         return valid
21 
22     num = len(roidb)
23     filtered_roidb = [entry for entry in roidb if is_valid(entry)]
24     num_after = len(filtered_roidb)
25     print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
26                                                        num, num_after)
27     return filtered_roidb

一般的标注信息都能满足上述2个要求。

1.5.2、SolverWrapper

在该类的初始化函数中,主要有以下操作:

函数中各配置参数的值如下:

cfg.TRAIN.HAS_RPN=True

cfg.TRAIN.BBOX_REG=True

cfg.TRAIN.BBOX_NORMALIZE_TARGETS=True

cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED=True

 1     def __init__(self, solver_prototxt, roidb, output_dir,
 2                  pretrained_model=None):
 3         """Initialize the SolverWrapper."""
 4         self.output_dir = output_dir
 5 
 6         if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
 7             cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
 8             # RPN can only use precomputed normalization because there are no
 9             # fixed statistics to compute a priori
10             assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
11 
12         if cfg.TRAIN.BBOX_REG:
13             print 'Computing bounding-box regression targets...'
14             self.bbox_means, self.bbox_stds = \
15                     rdl_roidb.add_bbox_regression_targets(roidb)
16             print 'done'
17 
18         self.solver = caffe.SGDSolver(solver_prototxt)
19         if pretrained_model is not None:
20             print ('Loading pretrained model '
21                    'weights from {:s}').format(pretrained_model)
22             self.solver.net.copy_from(pretrained_model)
23 
24         self.solver_param = caffe_pb2.SolverParameter()
25         with open(solver_prototxt, 'rt') as f:
26             pb2_text_format.Merge(f.read(), self.solver_param)
27 
28         self.solver.net.layers[0].set_roidb(roidb)

1.5.2.1、add_bbox_regression_targets

函数中self.bbox_means, self.bbox_stds = rdl_roidb.add_bbox_regression_targets(roidb)调用faster-rcnn/lib/roi_data_layer/roidb.py中的add_bbox_regression_targets函数

函数中各配置参数的值如下:

cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED=True

cfg.TRAIN.BBOX_NORMALIZE_TARGETS=True

 1 def add_bbox_regression_targets(roidb):
 2     """Add information needed to train bounding-box regressors."""
 3     assert len(roidb) > 0
 4     assert 'max_classes' in roidb[0], 'Did you call prepare_roidb first?'
 5 
 6     num_images = len(roidb)
 7     # Infer number of classes from the number of columns in gt_overlaps
 8     num_classes = roidb[0]['gt_overlaps'].shape[1]
 9     for im_i in xrange(num_images):
10         rois = roidb[im_i]['boxes']
11         max_overlaps = roidb[im_i]['max_overlaps']
12         max_classes = roidb[im_i]['max_classes']
13         roidb[im_i]['bbox_targets'] = \
14                 _compute_targets(rois, max_overlaps, max_classes)
15 
16     if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
17         # Use fixed / precomputed "means" and "stds" instead of empirical values
18         means = np.tile(
19                 np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS), (num_classes, 1))
20         stds = np.tile(
21                 np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS), (num_classes, 1))
22     else:
23         # Compute values needed for means and stds
24         # var(x) = E(x^2) - E(x)^2
25         class_counts = np.zeros((num_classes, 1)) + cfg.EPS
26         sums = np.zeros((num_classes, 4))
27         squared_sums = np.zeros((num_classes, 4))
28         for im_i in xrange(num_images):
29             targets = roidb[im_i]['bbox_targets']
30             for cls in xrange(1, num_classes):
31                 cls_inds = np.where(targets[:, 0] == cls)[0]
32                 if cls_inds.size > 0:
33                     class_counts[cls] += cls_inds.size
34                     sums[cls, :] += targets[cls_inds, 1:].sum(axis=0)
35                     squared_sums[cls, :] += \
36                             (targets[cls_inds, 1:] ** 2).sum(axis=0)
37 
38         means = sums / class_counts
39         stds = np.sqrt(squared_sums / class_counts - means ** 2)
40 
41     print 'bbox target means:'
42     print means
43     print means[1:, :].mean(axis=0) # ignore bg class
44     print 'bbox target stdevs:'
45     print stds
46     print stds[1:, :].mean(axis=0) # ignore bg class
47 
48     # Normalize targets
49     if cfg.TRAIN.BBOX_NORMALIZE_TARGETS:
50         print "Normalizing targets"
51         for im_i in xrange(num_images):
52             targets = roidb[im_i]['bbox_targets']
53             for cls in xrange(1, num_classes):
54                 cls_inds = np.where(targets[:, 0] == cls)[0]
55                 roidb[im_i]['bbox_targets'][cls_inds, 1:] -= means[cls, :]
56                 roidb[im_i]['bbox_targets'][cls_inds, 1:] /= stds[cls, :]
57     else:
58         print "NOT normalizing targets"
59 
60     # These values will be needed for making predictions
61     # (the predicts will need to be unnormalized and uncentered)
62     return means.ravel(), stds.ravel()

add_bbox_regression_targets首先计算所有边界框的回归目标(注意不是边界框的坐标),然后使用事先设定的均值和方差将回归目标标准化:

1 __C.TRAIN.BBOX_NORMALIZE_MEANS = (0.0, 0.0, 0.0, 0.0)
2 __C.TRAIN.BBOX_NORMALIZE_STDS = (0.1, 0.1, 0.2, 0.2)

因为从gt到gt的回归目标都为0,因此标准化之后仍然为0,我认为这一步有点多余。其中使用到的函数有_compute_targetsbbox_transform

1.5.2.2、set_roidb

在SolverWrapper的初始化函数中,接下来是构造一个caffe中的solver对象、加载与训练模型的参数。最后使用self.solver.net.layers[0].set_roidb(roidb)将上述的roidb传入网络的第一层,即input-data层中。set_roidb的具体代码如下:

函数中各配置参数的值如下:

cfg.TRAIN.USE_PREFETCH=False

cfg.TRAIN.ASPECT_GROUPING=True

 1     def set_roidb(self, roidb):
 2         """Set the roidb to be used by this layer during training."""
 3         self._roidb = roidb
 4         self._shuffle_roidb_inds()
 5         if cfg.TRAIN.USE_PREFETCH:
 6             self._blob_queue = Queue(10)
 7             self._prefetch_process = BlobFetcher(self._blob_queue,
 8                                                  self._roidb,
 9                                                  self._num_classes)
10             self._prefetch_process.start()
11             # Terminate the child process when the parent exists
12             def cleanup():
13                 print 'Terminating BlobFetcher'
14                 self._prefetch_process.terminate()
15                 self._prefetch_process.join()
16             import atexit
17             atexit.register(cleanup)
18 
19     def _shuffle_roidb_inds(self):
20         """Randomly permute the training roidb."""
21         if cfg.TRAIN.ASPECT_GROUPING:
22             widths = np.array([r['width'] for r in self._roidb])
23             heights = np.array([r['height'] for r in self._roidb])
24             horz = (widths >= heights)
25             vert = np.logical_not(horz)
26             horz_inds = np.where(horz)[0]
27             vert_inds = np.where(vert)[0]
28             inds = np.hstack((
29                 np.random.permutation(horz_inds),
30                 np.random.permutation(vert_inds)))
31             inds = np.reshape(inds, (-1, 2))
32             row_perm = np.random.permutation(np.arange(inds.shape[0]))
33             inds = np.reshape(inds[row_perm, :], (-1,))
34             self._perm = inds
35         else:
36             self._perm = np.random.permutation(np.arange(len(self._roidb)))
37         self._cur = 0

其中,使用到的函数有set_roidb_shuffle_roidb_inds。至此,faster-rcnn的数据准备阶段完成。

 

posted @ 2018-12-14 23:40  洗盏更酌  Views(2146)  Comments(0Edit  收藏  举报