TensorFlow 高效地推理 pb 模型，完整代码

from matplotlib import pyplot as plt
import numpy as np
import os

import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
import glob
from collections import defaultdict
from io import StringIO

from PIL import Image
import DrawBox
import cv2
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# Directory that receives the annotated output images.
SaveFilename = 'D:/FasterR-CNNImageTest/11/'
# makedirs also creates any missing parent directories (os.mkdir would fail
# if 'D:/FasterR-CNNImageTest/' does not exist yet), and exist_ok=True avoids
# the check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(SaveFilename, exist_ok=True)
# Run a frozen TensorFlow 1.x object-detection graph (.pb) over every *.jpg in
# PATH_TO_TEST_IMAGES_DIR, draw the detections on each image, and save the
# annotated copy under SaveFilename with the same file name.

PATH_TO_CKPT = 'pb/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('dataset', 'pascal_label_map.pbtxt')
NUM_CLASSES = 2
PATH_TO_TEST_IMAGES_DIR = 'E:/PythonOpenCVCode/BalanceGeneratePicture/TestSetSaveImage/Test/JPEGImages'

# Load the frozen model into its own tf.Graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
    # tf.device() scopes whichever graph is the default when it is entered.
    # It therefore has to be opened *inside* detection_graph.as_default().
    # An outer `with tf.device(...)` binds to the global default graph and
    # leaves these ops unpinned.
    # NOTE(review): assumes the frozen graph has no explicit device
    # assignments of its own (those would take precedence) - confirm.
    with tf.device('/cpu:0'):
        tf.import_graph_def(od_graph_def, name='')

# Map the class ids produced by the model to display names.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def load_image_into_numpy_array(image):
    """Return a PIL image as a writable (height, width, 3) uint8 RGB array.

    Converting through ``Image.convert('RGB')`` accepts grayscale, RGBA and
    palette images. It also avoids materialising a Python list of per-pixel
    tuples, which is what made the getdata()-based version slow.
    ``np.array`` (not ``np.asarray``) returns a writable copy, because the
    drawing step may modify the array in place.
    """
    return np.array(image.convert('RGB'), dtype=np.uint8)


cnt = 0  # number of images processed
with tf.Session(graph=detection_graph) as sess:
    # Input and output tensors of the exported detection graph.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box is a region of the image where an object was detected.
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Confidence of each detection; drawn on the result image with the label.
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    # glob already yields full paths. Re-joining them onto the directory
    # would duplicate the prefix whenever the directory is relative.
    for image_path in sorted(glob.glob(os.path.join(PATH_TO_TEST_IMAGES_DIR, '*.jpg'))):
        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(image_path) as image:
            image_np = load_image_into_numpy_array(image)
        # The model expects a batch: shape [1, height, width, 3].
        image_np_expanded = np.expand_dims(image_np, axis=0)
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # The return value is ignored, so this relies on DrawBox drawing into
        # image_np in place. NOTE(review): confirm for the local DrawBox copy.
        DrawBox.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            max_boxes_to_draw=400,
            use_normalized_coordinates=True,
            groundtruth_box_visualization_color='red',
            line_thickness=8)
        # image_np is RGB (it came from PIL), but cv2.imwrite expects BGR.
        # Without conversion, red and blue would be swapped in the saved file.
        out_path = os.path.join(SaveFilename, os.path.basename(image_path))
        if not cv2.imwrite(out_path, cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)):
            print('failed to write ' + out_path)
        cnt += 1
        print(image_path)
print('processed %d images' % cnt)

 

posted @ 2018-06-13 10:36  唐淼  阅读(1456)  评论(0)  编辑  收藏  举报