Tensorflow版Faster RCNN源码解析(TFFRCNN) (01) demo.py(含argparse模块,numpy模块中的newaxis、hstack、vstack和np.where等)

本blog为github上CharlesShang/TFFRCNN版源码解析系列代码笔记

---------------个人学习笔记---------------

----------------本文作者疆--------------

------点击此处链接至博客园原文------

 

1.主函数调用函数执行顺序:

parse_args()解析运行参数(如--gpu 0 ...)--->get_network(args.demo_net)加载网络(factory.py中)得到net

--->tf内部机制创建sess和恢复网络模型等--->glob.glob('图像地址')返回im_names地址列表(glob.py中)--->逐张图像

循环调用demo(sess,net,im_name)

2.parse_args()函数返回args

parser = argparse.ArgumentParser(description='Faster R-CNN demo') # 新建一个解析对象
parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',default=0, type=int)  # 含默认值
...
args = parser.parse_args() # 类内同名函数
# -*- coding:utf-8 -*-
# Author: WUJiang
# argparse模块功能测试
import argparse

parser = argparse.ArgumentParser(description="test")
parser.add_argument('--mode', dest='work_mode', default=0)  # 别名、默认值
parser.add_argument('--day', dest='date', default=4)
args = parser.parse_args()
print(args)  # Namespace(day=4, mode=0)
# args.date或args.day为4 args.work_mode或args.mode为0
View Code

3.demo()函数的执行逻辑

demo(sess, net, image_name)--->读取图像调用im_detect(sess, net, im)返回scores和boxes(应注意其维度,R为boxes个数)(test.py中)其中耗时仅统计了im_detect函数耗时,未统计nms等处理耗时

--->设置CONF_THRESH得分阈值和NMS_THRESH阈值--->针对各个类别,构造该类的dets(R*5,R表示R个box,5=4坐标+1得分)

--->该类内执行nms(nms_wrapper.py中)(IoU阈值为0.3)--->vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)对该类检测结果进行绘制,需用

到得分阈值(demo.py中)

def im_detect(sess, net, im, boxes=None):
    """Detect object classes in an image given object proposals.
    Arguments:
        net (caffe.Net): Fast R-CNN network to use
        im (ndarray): color image to test (in BGR order)
        boxes (ndarray): R x 4 array of object proposals
    Returns:
        scores (ndarray): R x K array of object class scores (K includes
            background as object category 0)
        boxes (ndarray): R x (4*K) array of predicted bounding boxes
    """
  # 针对每个boxes,得到其属于各类的得分及按各类得到的回归boxes,若得分阈值设置较低,会看到图像某个目标被检测出多类超过阈值得分的box盒
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]    # R*4
        cls_scores = scores[:, cls_ind]         # R*1
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)  # R*5
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)
# -*- coding:utf-8 -*-
# Author: WUJiang
# newaxis作用
import numpy as np

a = np.array([1, 2, 3, 4, 5])
b = a[np.newaxis, :]
c = a[:, np.newaxis]
# [[1 2 3 4 5]]
print(b)
"""
[[1]
 [2]
 [3]
 [4]
 [5]]
 """
print(c)
View Code
# -*- coding:utf-8 -*-
# Author: WUJiang
# 数组拼接

import numpy as np

a = np.array([
    [2, 2],
    [0, 3]
])
b = np.array([
    [4, 1],
    [3, 1]
])
"""
[[2 2 4 1]
 [0 3 3 1]]
"""
print(np.hstack((a, b)))
"""
[[2 2]
 [0 3]
 [4 1]
 [3 1]]
"""
print(np.vstack((a, b)))
View Code

4.demo()中图像目标检测时间的获取(Timer是在lib.utils.timer中定义的类)

from lib.utils.timer import Timer
timer = Timer()
timer.tic()
scores, boxes = im_detect(sess, net, im)
timer.toc()

实际上每次重新实例化timer,因此计算的时间即是t2-t1(简单地获取当时时间戳time.time()),而不是多张图像的平均检测时间

import time

class Timer(object):
    """A simple timer."""
    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        # using time.time instead of time.clock because time time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        else:
            return self.diff
View Code

5.vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)函数(暂时忽略matplotlib.pyplot即plt模块相关绘制功能)的执行逻辑

得到dets中得分超过CONF_THRESH的索引inds,对于该类遍历各个超过CONF_THRESH的bbox进行绘制

    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
...
View Code
# -*- coding:utf-8 -*-
# Author: WUJiang
# np.where测试

import numpy as np

a = np.array([
    [1, 2, 4, 5, 0.8],
    [2, 5, 7, 9, 0.9],
    [3, 6, 5, 20, 0.95]
])
# (array([1, 1, 2, 2], dtype=int64), array([2, 3, 1, 3], dtype=int64)) 对应于4个数组元素位置
# <class 'tuple'>
b = np.where(a > 5)
# [1 1 2 2]
# <class 'numpy.ndarray'>
b0 = b[0]
print(type(b0))
# [0.8  0.9  0.95]
print(a[:, 4])
# (array([1, 2], dtype=int64),)
# # <class 'tuple'>
c = np.where(a[:, 4] > 0.8)
# [1 2]
# <class 'numpy.ndarray'>
print(c[0])
View Code
posted @ 2019-06-27 10:29  JiangJ~  阅读(840)  评论(0编辑  收藏  举报