TensorFlow 利用预训练模型进行目标检测(三):将检测结果存入 MySQL 数据库
MySQL 版本:5.7;数据库:rdshare。表 captain_america3_sd 用来记录某帧是否已被检测;表 captain_america3_d 用来记录检测到的数据。
Python 模块和包的相关内容参考:http://www.runoob.com/python/python-modules.html 以及 https://www.cnblogs.com/ningskyer/articles/6025964.html
一、连接数据库
参考(下面是另一个项目中把视频文件路径写入 MySQL 数据库的示例代码):
# Insert a video's file path into the database (reference example).
def video_insert(filename, couse_id):
    """Store ``filename`` in the ``json`` column of the matching course report.

    Looks up the ``course_report`` row whose ``courseh_id`` equals
    ``couse_id``. If such a row exists, ``filename`` is written into its
    ``json`` column and the change is committed.

    Args:
        filename: path of the video file to record.
        couse_id: course id parsed by the caller. It is matched against
            ``course_report.courseh_id``, which holds the course id from the
            course record table.

    Returns:
        The number of rows updated (0 when no matching course report exists).
    """
    conn = MySQLdb.connect(user='root', passwd='****',
                           host='sh-cdb-myegtz7i.sql.tencentcdb.com',
                           port=63619, db='bitbear', charset='utf8')
    try:
        cursor = conn.cursor()
        try:
            # courser_id is the primary key of course_report. courseh_id is
            # the course_id from the course record table.
            # Parameterized query: the driver escapes couse_id instead of it
            # being pasted into the SQL text.
            cursor.execute(
                "SELECT courser_id FROM course_report WHERE courseh_id = %s",
                (couse_id,))
            results = cursor.fetchall()
            if not results:
                # No report for this course: the UPDATE below needs a
                # courser_id, so there is nothing to do.
                return 0
            print(results)
            courser_id = results[0][0]
            print(courser_id)

            # The value stored is the file path itself.
            rarpath = filename
            print(rarpath)

            cursor.execute(
                "UPDATE course_report SET json = %s WHERE courser_id = %s",
                (rarpath, courser_id))
            updated = cursor.rowcount
            conn.commit()
            return updated
        finally:
            cursor.close()
    finally:
        # The original closed only the cursor and leaked the connection.
        conn.close()
首先需要安装 MySQL 驱动(Python 2 的 MySQLdb 模块):sudo apt-get install python-mysqldb
安装完成之后可以在Python解释器中测试一下
输入 import MySQLdb #注意大小写
如果不报错,就证明安装成功了。
简单测试版本
# Store the detection result in the MySQL database (simple test version).
def detection_to_database(object_name):
    """Insert one row with ``is_detected = 1`` into ``captain_america3_sd``.

    Args:
        object_name: name of the detected object. It is not used by this
            test version; only the ``is_detected`` flag is written.
    """
    conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                           port=3306, db='rdshare', charset='utf8')
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(
                "INSERT INTO captain_america3_sd (is_detected) VALUES (1)")
            conn.commit()
        finally:
            cursor.close()
    finally:
        # The original closed only the cursor, leaking one connection per call.
        conn.close()
二、修改文件结构
在同一目录下新建 detection_control.py 文件,它相当于 main 文件:读入命令行参数,并控制检测(detection)的流程。
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""Entry point: read command-line options and run detection frame by frame."""
import datetime
import os
import time
import argparse
import sys

# Suppress TensorFlow C++ log messages (level 3 hides INFO, WARNING and ERROR).
# Set before `detection`, which imports tensorflow, is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import detection as mod_detection

if sys.version_info[0] < 3:
    # Python 2 only: make utf-8 the implicit str/unicode codec. reload() and
    # sys.setdefaultencoding() do not exist on Python 3, so guard the hack.
    reload(sys)
    sys.setdefaultencoding('utf8')


def parse_args():
    '''Parse command-line options.

    Frames image_start_num .. image_end_num - 1 are processed; the end number
    is exclusive.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path',
                        default='/home/yanjieliu/rdshare/dataset/ca36000_36100/')
    parser.add_argument('--image_start_num', type=int, default=36000)
    parser.add_argument('--image_end_num', type=int, default=36002)
    parser.add_argument('--model_name',
                        default='ssd_inception_v2_coco_2018_01_28')
    return parser.parse_args()


def main():
    """Run detection.Detection on every frame in the requested range."""
    args = parse_args()
    for frame_num in range(args.image_start_num, args.image_end_num):
        print(frame_num)
        # Call Detection from detection.py and pass the options along.
        mod_detection.Detection(args, frame_num)


if __name__ == '__main__':
    main()
该文件逐帧调用 detection.py 中的 Detection 函数进行识别。
detection.py文件内容如下
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""Run a pretrained object-detection model on video frames and record in
MySQL (database ``rdshare``) which frames have been detected."""
import os
import sys

# Suppress TensorFlow C++ log messages (level 3 hides INFO, WARNING and ERROR).
# Set before tensorflow is imported so it is in effect from the start.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import matplotlib
# Non-interactive backend; must be selected before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot
from matplotlib import pyplot as plt
import tensorflow as tf
from PIL import Image
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
import datetime
import time
import MySQLdb
import argparse

if sys.version_info[0] < 3:
    # Python 2 only: make utf-8 the implicit str/unicode codec. reload() and
    # sys.setdefaultencoding() do not exist on Python 3, so guard the hack.
    reload(sys)
    sys.setdefaultencoding('utf8')

# Single graph that holds the imported detection model for this process.
detection_graph = tf.Graph()
# Name of the model already imported into detection_graph (None = none yet).
_loaded_model_name = None


def detection_to_database(object_name, frame_num):
    """Mark ``frame_num`` as detected in table ``captain_america3_sd``.

    Looks the frame up first. An existing row is updated; otherwise a new
    row is inserted. Calling this several times for one frame is harmless.

    Args:
        object_name: class name of the detected object. Not stored yet.
            NOTE(review): presumably meant for table captain_america3_d.
        frame_num: number of the frame being recorded.
    """
    conn = MySQLdb.connect(user='root', passwd='****', host='localhost',
                           port=3306, db='rdshare', charset='utf8')
    try:
        cursor = conn.cursor()
        try:
            # Parameterized queries. frame_num is passed as a string, matching
            # the quoted literal the original SQL text used.
            frame_key = str(frame_num)
            cursor.execute(
                "SELECT is_detected FROM captain_america3_sd "
                "WHERE frame_num = %s", (frame_key,))
            results = cursor.fetchall()
            if results:
                print(results)
                # Restrict the update to this frame. The original statement
                # had no WHERE clause and flagged every row in the table.
                cursor.execute(
                    "UPDATE captain_america3_sd SET is_detected = 1 "
                    "WHERE frame_num = %s", (frame_key,))
            else:
                print('null')
                cursor.execute(
                    "INSERT INTO captain_america3_sd (is_detected, frame_num) "
                    "VALUES (1, %s)", (frame_key,))
            conn.commit()
        finally:
            cursor.close()
    finally:
        # The original closed only the cursor and leaked the connection.
        conn.close()


# Load model data -------------------------------------------------------------
def loading(model_name):
    """Import the frozen inference graph of ``model_name`` into detection_graph.

    The import runs only once per model name. Importing again on every frame,
    as the original did, adds another copy of the model's ops to the same
    graph each time.

    Returns:
        The shared ``detection_graph``.
    """
    global _loaded_model_name
    if _loaded_model_name == model_name:
        return detection_graph
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        PATH_TO_CKPT = ('/home/yanjieliu/models/models/research/object_detection/pretrained_models/'
                        + model_name + '/frozen_inference_graph.pb')
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    _loaded_model_name = model_name
    return detection_graph


# Detection -------------------------------------------------------------------
def load_image_into_numpy_array(image):
    """Return a PIL ``image`` as a ``(height, width, 3)`` uint8 numpy array.

    Non-RGB images (e.g. grayscale JPEGs) are converted to RGB first, so the
    3-channel reshape does not fail.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


# Label map: maps the class ids the model outputs to display names, used to
# label each box.
PATH_TO_LABELS = os.path.join(
    '/home/yanjieliu/models/models/research/object_detection/data',
    'mscoco_label_map.pbtxt')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=90, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def Detection(args, frame_num):
    """Run object detection on one frame and record the result.

    Reads ``image-<frame_num>.jpeg`` from ``args.image_path`` and runs the
    model ``args.model_name`` on it. The first three detections are printed,
    and ``detection_to_database`` is called for each one whose class id is in
    the label map. The boxes are then drawn on the image and it is plotted.

    Args:
        args: parsed options; ``image_path`` and ``model_name`` are read.
        frame_num: number of the frame to process.
    """
    loading(args.model_name)

    image_file = os.path.join(args.image_path, 'image-%s.jpeg' % frame_num)
    with Image.open(image_file) as image:
        # Array form of the image: model input, and canvas the boxes and
        # labels are drawn on.
        image_np = load_image_into_numpy_array(image)
    # The model expects a batch of images: shape [1, None, None, 3].
    image_np_expanded = np.expand_dims(image_np, axis=0)

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box is a region of the image where an object was detected.
            boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
            # Confidence of each detection; shown with the class label.
            scores = detection_graph.get_tensor_by_name('detection_scores:0')
            classes = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections = detection_graph.get_tensor_by_name(
                'num_detections:0')
            # Actual detection.
            (boxes, scores, classes, num_detections) = sess.run(
                [boxes, scores, classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})

    # Draw the detection results onto image_np.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8)

    # Output the first three detections, or fewer if the model returned fewer.
    # The model presumably orders them by score.
    for i in range(min(3, scores.shape[1])):
        # The model returns class ids as floats; the label map keys are ints.
        class_id = int(classes[0][i])
        if class_id in category_index:
            class_name = category_index[class_id]['name']
            detection_to_database(class_name, frame_num)
        else:
            class_name = 'N/A'
        print("object:%s gailv:%s" % (class_name, scores[0][i]))

    # Plot the annotated image. IMAGE_SIZE is the figure size in inches.
    # NOTE: with the non-interactive Agg backend selected above, plt.show()
    # opens no window; use plt.savefig(...) to keep the output.
    IMAGE_SIZE = (20, 12)
    plt.figure(figsize=IMAGE_SIZE)
    plt.imshow(image_np)
    plt.show()
    # Release the figure so one is not kept alive per processed frame.
    plt.close()


def parse_args():
    '''Parse command-line options (the end frame number is exclusive).'''
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path',
                        default='/home/yanjieliu/rdshare/dataset/ca36000_36100/')
    parser.add_argument('--image_start_num', type=int, default=36000)
    parser.add_argument('--image_end_num', type=int, default=36002)
    parser.add_argument('--model_name',
                        default='ssd_inception_v2_coco_2018_01_28')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    start = time.time()
    # The original called Detection with an undefined frame_num. Process the
    # requested frame range instead, as detection_control.py does.
    for frame_num in range(args.image_start_num, args.image_end_num):
        Detection(args, frame_num)
    end = time.time()
    print('time:\n')
    print(str(end - start))
    # To collect timings, append str(end - start) to a file such as
    # ./outputs/1to10test_outputs.txt.