Tensorflow版Faster RCNN源码解析(TFFRCNN) (15) VGGnet_train.py
本blog为github上CharlesShang/TFFRCNN版源码解析系列代码笔记
---------------个人学习笔记---------------
----------------本文作者疆--------------
------点击此处链接至博客园原文------
与VGGnet_test.py相比,VGGnet_train.py需要馈入更多的变量,与train.py中train_model(...)函数定义的feed_dict相对应;此外,还增加了name为rpn-data、roi-data、drop6和drop7的网络处理层。keep_prob表示dropout中神经元的保留概率(keep probability),而不是丢弃比例;另外需要注意,下面代码中drop6、drop7层实际传入的是硬编码的0.5,并没有使用keep_prob占位符。
# train.py中train_model(...)函数定义的feed_dict feed_dict={ self.net.data: blobs['data'], self.net.im_info: blobs['im_info'], self.net.keep_prob: 0.5, self.net.gt_boxes: blobs['gt_boxes'], self.net.gt_ishard: blobs['gt_ishard'], self.net.dontcare_areas: blobs['dontcare_areas'] }
VGGnet_train.py代码及注释如下:
import tensorflow as tf
# Explicit relative import, consistent with the `..fast_rcnn` import below.
# The previous implicit form (`from network import Network`) only resolves
# on Python 2; the explicit form works on Python 2 and 3 alike.
from .network import Network
from ..fast_rcnn.config import cfg


class VGGnet_train(Network):
    """VGG16-based Faster R-CNN network graph used for training.

    Compared with the test-time network, this class declares extra input
    placeholders (gt_boxes, gt_ishard, dontcare_areas, keep_prob). These must
    match the feed_dict built in train.py's train_model(...). It also adds the
    training-only layers 'rpn-data', 'roi-data', 'drop6' and 'drop7'.
    """

    def __init__(self, trainable=True):
        """Create the input placeholders and build the graph via setup().

        Args:
            trainable: stored on ``self.trainable``. Presumably consulted by
                the Network layer helpers when creating variables; TODO
                confirm in network.py.
        """
        # Inputs for the next layer (a list); filled in by feed().
        self.inputs = []
        # Image batch: 4-D float tensor whose last dimension is 3.
        self.data = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data')
        # Per-image info, 3 values per row (presumably height, width, scale;
        # TODO confirm).
        self.im_info = tf.placeholder(tf.float32, shape=[None, 3], name='im_info')
        # The placeholders below correspond to the feed_dict in train.py's
        # train_model(...).
        # Ground-truth boxes, 5 values per row (presumably x1, y1, x2, y2, cls;
        # TODO confirm).
        self.gt_boxes = tf.placeholder(tf.float32, shape=[None, 5], name='gt_boxes')
        self.gt_ishard = tf.placeholder(tf.int32, shape=[None], name='gt_ishard')
        self.dontcare_areas = tf.placeholder(tf.float32, shape=[None, 4], name='dontcare_areas')
        # Dropout keep probability. NOTE(review): train.py feeds this value,
        # but the drop6/drop7 layers in setup() pass a hard-coded 0.5 instead
        # of this placeholder.
        self.keep_prob = tf.placeholder(tf.float32)
        # Layer name -> output. The feed('<name>') calls in setup() refer to
        # these entries (and to layers added later) by name.
        self.layers = {'data': self.data, 'im_info': self.im_info, 'gt_boxes': self.gt_boxes,
                       'gt_ishard': self.gt_ishard, 'dontcare_areas': self.dontcare_areas}
        self.trainable = trainable
        self.setup()

    def setup(self):
        """Build the VGG16 backbone, RPN, proposal/target layers and RCNN head."""
        n_classes = cfg.NCLASSES
        anchor_scales = cfg.ANCHOR_SCALES
        _feat_stride = [16, ]
        # Anchors per feature-map position: one per scale times 3 (presumably
        # the number of aspect ratios; TODO confirm in the anchor generator).
        num_anchors = len(anchor_scales) * 3

        # ========= VGG16 backbone =========
        # feed() returns self, so layer calls can be chained.
        # conv1_x and conv2_x are built with trainable=False; later layers
        # leave trainable at conv()'s default.
        (self.feed('data')
             .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False)
             .conv(3, 3, 64, 1, 1, name='conv1_2', trainable=False)
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool1')
             .conv(3, 3, 128, 1, 1, name='conv2_1', trainable=False)
             .conv(3, 3, 128, 1, 1, name='conv2_2', trainable=False)
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool2')
             .conv(3, 3, 256, 1, 1, name='conv3_1')
             .conv(3, 3, 256, 1, 1, name='conv3_2')
             .conv(3, 3, 256, 1, 1, name='conv3_3')
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool3')
             .conv(3, 3, 512, 1, 1, name='conv4_1')
             .conv(3, 3, 512, 1, 1, name='conv4_2')
             .conv(3, 3, 512, 1, 1, name='conv4_3')
             .max_pool(2, 2, 2, 2, padding='VALID', name='pool4')
             .conv(3, 3, 512, 1, 1, name='conv5_1')
             .conv(3, 3, 512, 1, 1, name='conv5_2')
             .conv(3, 3, 512, 1, 1, name='conv5_3'))

        # ========= RPN =========
        (self.feed('conv5_3')
             .conv(3, 3, 512, 1, 1, name='rpn_conv/3x3'))

        # Box regression output: (1, H, W, A x 4).
        (self.feed('rpn_conv/3x3')
             .conv(1, 1, num_anchors * 4, 1, 1, padding='VALID', relu=False, name='rpn_bbox_pred'))

        # Objectness scores: (1, H, W, A x 2).
        (self.feed('rpn_conv/3x3')
             .conv(1, 1, num_anchors * 2, 1, 1, padding='VALID', relu=False, name='rpn_cls_score'))

        # Training only: generate anchor classification labels and box
        # regression targets on the fly from the ground truth. Per the
        # original notes, the outputs are rpn_labels (HxWxA, 2),
        # rpn_bbox_targets (HxWxA, 4), rpn_bbox_inside_weights and
        # rpn_bbox_outside_weights.
        (self.feed('rpn_cls_score', 'gt_boxes', 'gt_ishard', 'dontcare_areas', 'im_info')
             .anchor_target_layer(_feat_stride, anchor_scales, name='rpn-data'))

        # Per-anchor 2-way softmax: reshape (1, H, W, Ax2) -> (1, H, WxA, 2),
        # apply softmax, then reshape back to (1, H, W, Ax2).
        (self.feed('rpn_cls_score')
             .spatial_reshape_layer(2, name='rpn_cls_score_reshape')
             .spatial_softmax(name='rpn_cls_prob'))

        (self.feed('rpn_cls_prob')
             .spatial_reshape_layer(num_anchors * 2, name='rpn_cls_prob_reshape'))

        # ========= RoI proposal =========
        # Apply the predicted deltas to the anchors, then keep reasonable boxes
        # based on scores, ratios, size and IoU (see proposal_layer_tf.py).
        # Called in 'TRAIN' mode. Per the original notes, output rows look
        # like [0, x1, y1, x2, y2].
        (self.feed('rpn_cls_prob_reshape', 'rpn_bbox_pred', 'im_info')
             .proposal_layer(_feat_stride, anchor_scales, 'TRAIN', name='rpn_rois'))

        # Training only: match proposals against the ground truth and randomly
        # sample RoIs plus labels for the RCNN head.
        (self.feed('rpn_rois', 'gt_boxes', 'gt_ishard', 'dontcare_areas')
             .proposal_target_layer(n_classes, name='roi-data'))

        # ========= RCNN head =========
        # Pool from the sampled RoIs registered above as 'roi-data'. The
        # previous feed('conv5_3', 'rois') named a layer that this training
        # graph never creates. NOTE(review): relies on roi_pool picking the
        # RoI tensor out of roi-data's output; confirm in network.py.
        # drop6/drop7 pass a fixed 0.5, not self.keep_prob (kept as-is).
        (self.feed('conv5_3', 'roi-data')
             .roi_pool(7, 7, 1.0 / 16, name='pool_5')
             .fc(4096, name='fc6')
             .dropout(0.5, name='drop6')
             .fc(4096, name='fc7')
             .dropout(0.5, name='drop7')
             .fc(n_classes, relu=False, name='cls_score')
             .softmax(name='cls_prob'))

        # Per-class box regression branches off drop7, in parallel with
        # cls_score.
        (self.feed('drop7')
             .fc(n_classes * 4, relu=False, name='bbox_pred'))