TensorFlow Faster RCNN source code walkthrough (TFFRCNN) (01): demo.py (covering the argparse module and NumPy's newaxis, hstack, vstack, and np.where)
This post is part of a series of code notes walking through the CharlesShang/TFFRCNN source on GitHub.
---------------Personal study notes---------------
----------------Author: Jiang (疆)--------------
------Originally published on cnblogs (博客园)------
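1. Order of function calls in the main routine:
parse_args() parses the run arguments (e.g. --gpu 0 ...) ---> get_network(args.demo_net) loads the network (in factory.py) and returns net
---> TensorFlow's own machinery creates sess and restores the trained model ---> glob.glob('image path') returns the list of image paths im_names (in glob.py) ---> for each image,
the loop calls demo(sess, net, im_name)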
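The last two steps of this flow (collect image paths with glob, then loop over them) can be sketched without TensorFlow. This is only an illustration: `demo_stub` is a hypothetical stand-in for `demo(sess, net, im_name)`, and the temporary folder replaces the real image directory.

```python
import glob
import os
import tempfile


def demo_stub(im_name):
    """Hypothetical stand-in for demo(sess, net, im_name)."""
    return os.path.basename(im_name)


with tempfile.TemporaryDirectory() as folder:
    # Create a few empty files so the pattern has something to match
    for name in ("b.jpg", "a.jpg", "notes.txt"):
        open(os.path.join(folder, name), "w").close()

    # glob.glob returns matches in arbitrary order, so sort for stable output
    im_names = sorted(glob.glob(os.path.join(folder, "*.jpg")))
    processed = [demo_stub(im_name) for im_name in im_names]

print(processed)  # ['a.jpg', 'b.jpg']
```

Only files matching the pattern are returned, so `notes.txt` is skipped.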
2. The parse_args() function returns args
parser = argparse.ArgumentParser(description='Faster R-CNN demo')  # create a parser object
parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                    default=0, type=int)  # with a default value
...
args = parser.parse_args()  # the parser's own method of the same name
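# -*- coding:utf-8 -*-
# Author: WUJiang
# Testing the argparse module

import argparse

parser = argparse.ArgumentParser(description="test")
parser.add_argument('--mode', dest='work_mode', default=0)  # dest sets the attribute name, default the fallback value
parser.add_argument('--day', dest='date', default=4)
args = parser.parse_args()
print(args)  # Namespace(work_mode=0, date=4); the attribute order can differ between Python versions
# args.work_mode is 0 and args.date is 4.
# Because dest is given, the attributes are named after dest: args.mode and args.day do not exist.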
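To check how `dest` behaves without touching the command line, an explicit argument list can be passed to `parse_args`. The sketch below reuses the same two options and adds `type=int` (an addition for this example, not part of the test script above) so that values given on the command line are converted to integers.

```python
import argparse

parser = argparse.ArgumentParser(description="dest demo")
parser.add_argument("--mode", dest="work_mode", default=0, type=int)
parser.add_argument("--day", dest="date", default=4, type=int)

# Pass an explicit argument list instead of reading sys.argv
args = parser.parse_args(["--mode", "2"])

print(args.work_mode)         # 2, converted to int because of type=int
print(args.date)              # 4, the default
print(hasattr(args, "mode"))  # False: the attribute is named after dest
```

Without `dest`, argparse would derive the attribute name from the flag itself (`args.mode`, `args.day`).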
3. Execution logic of demo()
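demo(sess, net, image_name) ---> read the image and call im_detect(sess, net, im) (in test.py), which returns scores and boxes (mind their dimensions; R is the number of boxes). The reported time covers only im_detect, not NMS or other post-processing
---> set the score threshold CONF_THRESH and the NMS threshold NMS_THRESH ---> for each class, build that class's dets (R*5: R boxes, 5 = 4 coordinates + 1 score)
---> run nms within that class (in nms_wrapper.py) (IoU threshold 0.3) ---> vis_detections(im, cls, dets, ax, thresh=CONF_THRESH) (in demo.py) draws that class's detections, which is where the score threshold is applied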
def im_detect(sess, net, im, boxes=None):
    """Detect object classes in an image given object proposals.
    Arguments:
        net (caffe.Net): Fast R-CNN network to use
        im (ndarray): color image to test (in BGR order)
        boxes (ndarray): R x 4 array of object proposals
    Returns:
        scores (ndarray): R x K array of object class scores (K includes
            background as object category 0)
        boxes (ndarray): R x (4*K) array of predicted bounding boxes
    """
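# For every box, im_detect returns its score for each class and the box regressed for each class. If the score threshold is set low, a single object in the image can end up with boxes from several classes whose scores all exceed the threshold.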
CONF_THRESH = 0.8
NMS_THRESH = 0.3
for cls_ind, cls in enumerate(CLASSES[1:]):
    cls_ind += 1  # because we skipped background
    cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]  # shape (R, 4)
    cls_scores = scores[:, cls_ind]  # shape (R,), 1-D, hence np.newaxis below
    dets = np.hstack((cls_boxes,
                      cls_scores[:, np.newaxis])).astype(np.float32)  # shape (R, 5)
    keep = nms(dets, NMS_THRESH)
    dets = dets[keep, :]
    vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)
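The shapes in the loop above are easy to verify on dummy data. In this sketch, `R`, `K`, and the random arrays are made up for illustration; only the slicing and stacking mirror the demo code.

```python
import numpy as np

R, K = 3, 4  # hypothetical: 3 boxes, 4 classes including background
scores = np.random.rand(R, K).astype(np.float32)     # R x K
boxes = np.random.rand(R, 4 * K).astype(np.float32)  # R x (4*K)

cls_ind = 2
cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]  # the 4 columns of class 2
cls_scores = scores[:, cls_ind]                       # 1-D vector of length R
dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis]))

print(cls_boxes.shape)   # (3, 4)
print(cls_scores.shape)  # (3,)
print(dets.shape)        # (3, 5)
```

`np.hstack` needs both inputs to be 2-D with the same number of rows here, which is why the 1-D score vector is first turned into a column.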
# -*- coding:utf-8 -*-
# Author: WUJiang
# What newaxis does

import numpy as np

a = np.array([1, 2, 3, 4, 5])
b = a[np.newaxis, :]  # shape (1, 5)
c = a[:, np.newaxis]  # shape (5, 1)
# [[1 2 3 4 5]]
print(b)
"""
[[1]
 [2]
 [3]
 [4]
 [5]]
"""
print(c)
# -*- coding:utf-8 -*-
# Author: WUJiang
# Array concatenation

import numpy as np

a = np.array([
    [2, 2],
    [0, 3]
])
b = np.array([
    [4, 1],
    [3, 1]
])
"""
[[2 2 4 1]
 [0 3 3 1]]
"""
print(np.hstack((a, b)))
"""
[[2 2]
 [0 3]
 [4 1]
 [3 1]]
"""
print(np.vstack((a, b)))
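The per-class NMS step in this section can be sketched in plain NumPy. This is a minimal greedy NMS using the common "+1 pixel" area convention; the function name `nms_numpy` is made up here, and the code is not taken from nms_wrapper.py, whose actual implementation may differ.

```python
import numpy as np


def nms_numpy(dets, thresh):
    """Greedy NMS on an (R, 5) array of [x1, y1, x2, y2, score] rows.

    Returns the indices of the kept boxes, highest score first.
    """
    x1, y1, x2, y2, scores = (dets[:, i] for i in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # indices sorted by descending score

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # Intersection of the current best box with every remaining box
        w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)
        # Drop boxes that overlap the current best box too much
        order = rest[iou <= thresh]
    return keep


dets = np.array([
    [10, 10, 50, 50, 0.9],
    [12, 12, 52, 52, 0.8],     # overlaps the first box heavily
    [100, 100, 150, 150, 0.7]  # far away from both
], dtype=np.float32)
print(nms_numpy(dets, 0.3))  # [0, 2]
```

The second box has an IoU of about 0.83 with the first, well above 0.3, so it is suppressed; the third box does not overlap and survives.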
4. How demo() measures the detection time per image (Timer is a class defined in lib.utils.timer)
from lib.utils.timer import Timer

timer = Timer()
timer.tic()
scores, boxes = im_detect(sess, net, im)
timer.toc()
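In practice a new timer is instantiated on every call to demo(), so calls is always 1 and the reported time is simply t2 - t1 (two time.time() timestamps), not the average detection time over several images.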
import time


class Timer(object):
    """A simple timer."""
    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        # using time.time instead of time.clock because time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        else:
            return self.diff
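To see the difference between a fresh timer and a shared one, the sketch below copies the Timer class in compact form and replaces im_detect with a short `time.sleep`, which is a placeholder for this illustration only. Reusing one instance across calls is what would give a real running average.

```python
import time


class Timer(object):
    """Compact copy of the timer class, for illustration."""
    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff


# One shared timer: toc() returns the running average over all calls so far
shared = Timer()
for _ in range(3):
    shared.tic()
    time.sleep(0.01)  # placeholder for im_detect(sess, net, im)
    avg = shared.toc()

# A fresh timer per call, as in demo(): calls == 1, so the "average" is just diff
fresh = Timer()
fresh.tic()
time.sleep(0.01)
single = fresh.toc()

print(shared.calls, fresh.calls)  # 3 1
```

Passing `average=False` to `toc()` returns the duration of the last call regardless of how many calls came before.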
5. Execution logic of vis_detections(im, cls, dets, ax, thresh=CONF_THRESH) (ignoring the matplotlib.pyplot, i.e. plt, drawing calls for now)
Find the indices inds of the rows in dets whose score is at least CONF_THRESH, then iterate over those boxes of the class and draw each one
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
    return
for i in inds:
    bbox = dets[i, :4]
    score = dets[i, -1]
    ...
# -*- coding:utf-8 -*-
# Author: WUJiang
# Testing np.where

import numpy as np

a = np.array([
    [1, 2, 4, 5, 0.8],
    [2, 5, 7, 9, 0.9],
    [3, 6, 5, 20, 0.95]
])
# (array([1, 1, 2, 2], dtype=int64), array([2, 3, 1, 3], dtype=int64))
# i.e. the (row, column) positions of the 4 matching elements
# <class 'tuple'>
b = np.where(a > 5)
# b0 is [1 1 2 2], the row indices
# <class 'numpy.ndarray'>
b0 = b[0]
print(type(b0))
# [0.8  0.9  0.95]
print(a[:, 4])
# (array([1, 2], dtype=int64),)
# <class 'tuple'>
c = np.where(a[:, 4] > 0.8)
# [1 2]
# <class 'numpy.ndarray'>
print(c[0])
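Applied to a dets array, the same pattern selects the rows that pass the score threshold. The values and the 0.85 threshold below are invented for this sketch; it also shows that a boolean mask selects the same rows as the indices from `np.where`.

```python
import numpy as np

# Hypothetical detections: [x1, y1, x2, y2, score]
dets = np.array([
    [1, 2, 4, 5, 0.8],
    [2, 5, 7, 9, 0.9],
    [3, 6, 5, 20, 0.95]
])
thresh = 0.85

# np.where on a 1-D condition returns a 1-tuple; [0] extracts the index array
inds = np.where(dets[:, -1] >= thresh)[0]
print(inds)  # [1 2]

for i in inds:
    bbox = dets[i, :4]
    score = dets[i, -1]
    print(bbox, score)

# Boolean-mask indexing picks the same rows without explicit indices
kept = dets[dets[:, -1] >= thresh]
print(kept.shape)  # (2, 5)
```

The index form is convenient when the loop needs `i` itself; the mask form is shorter when only the filtered rows are needed.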