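TensorFlow Object Detection pretrained models: a guide to the pitfalls

Notes on the various pitfalls I ran into while using TensorFlow's pretrained object detection models.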

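The Python script I used comes from this blog post:
Training your own object detection model with the Tensorflow object detection API (detecting objects in images and video)

The original author used TensorFlow 1.14.0, so I made some changes to adapt the script to TensorFlow 2.0.
The code is as follows: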

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO

# from matplotlib import pyplot as plt
import matplotlib
import matplotlib.pyplot  # imported explicitly: `import matplotlib` alone does not load pyplot

from PIL import Image

# # This is needed to display the images.
# %matplotlib inline

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")

# from utils import label_map_util
# from utils import visualization_utils as vis_util
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# What model to download.
MODEL_NAME = 'other_library_files/ssd_mobilenet_v1_coco_2017_11_17'
# MODEL_FILE = MODEL_NAME + '.tar.gz'
# DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
PATH_TO_LABELS = 'other_library_files/models-master/research/object_detection/data/mscoco_label_map.pbtxt'

NUM_CLASSES = 90

# download model
# opener = urllib.request.URLopener()
# Download the model; if it has already been downloaded, the line below can stay commented out
# opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
# tar_file = tarfile.open(MODEL_FILE)
# for file in tar_file.getmembers():
#    file_name = os.path.basename(file.name)
#    if 'frozen_inference_graph.pb' in file_name:
#        tar_file.extract(file, os.getcwd())

# Load a (frozen) Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
# # Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                             use_display_name=True)
category_index = label_map_util.create_category_index(categories)


# Helper code
def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'other_library_files/models-master/research/object_detection/test_images'  # folder of test images
TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 6)]  # image1.jpg .. image5.jpg in that folder

# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)

# Run detection
with detection_graph.as_default():
    with tf.compat.v1.Session(graph=detection_graph) as sess:
        # Definite input and output Tensors for detection_graph
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represent how level of confidence for each of the objects.
        # Score is shown on the result image, together with the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        # Run detection on each image in turn
        for image_path in TEST_IMAGE_PATHS:
            image = Image.open(image_path)
            # The array-based representation of the image is used later to prepare the
            # result image with boxes and labels on it.
            image_np = load_image_into_numpy_array(image)
            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
            image_np_expanded = np.expand_dims(image_np, axis=0)
            # Actual detection, reusing the tensors looked up before the loop.
            # boxes: detected regions, scores: confidence, classes: class ids, num: detection count
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            # Visualize the detection results by drawing boxes on the image.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,  # box coordinates are normalized
                line_thickness=8)  # thickness of the box outline

            # Display the result
            matplotlib.use('TkAgg')
            matplotlib.pyplot.figure(figsize=IMAGE_SIZE)
            matplotlib.pyplot.imshow(image_np)
            matplotlib.pyplot.show()
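
The arrays returned by sess.run can also be used directly rather than only being drawn. As a minimal sketch, assuming the usual Object Detection API convention that each box is [ymin, xmin, ymax, xmax] in normalized coordinates, the hypothetical helper below (not part of the original script) converts the confident detections to pixel coordinates. The 0.5 threshold is an assumption chosen for illustration.

```python
def boxes_to_pixels(boxes, scores, classes, im_width, im_height, min_score=0.5):
    """Convert normalized [ymin, xmin, ymax, xmax] boxes to pixel boxes.

    Returns a list of (class_id, score, (left, top, right, bottom)) tuples
    for every detection whose score is at least min_score.
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue  # skip low-confidence detections
        ymin, xmin, ymax, xmax = box
        pixel_box = (
            int(round(xmin * im_width)),
            int(round(ymin * im_height)),
            int(round(xmax * im_width)),
            int(round(ymax * im_height)),
        )
        results.append((int(cls), float(score), pixel_box))
    return results
```

In the script above it could be called as boxes_to_pixels(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), *image.size), since image.size is (width, height).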

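Problems encountered:

1. Missing the object_detection module

ModuleNotFoundError: No module named 'object_detection'

Using TensorFlow's pretrained models requires the TensorFlow models repository. It contains many tools, covering image classification, detection, natural language processing (NLP), video prediction, image understanding and more, and the Object Detection API we need is part of it. Download the zip package directly from GitHub. If your TensorFlow lives in an Anaconda environment, go to the Lib\site-packages directory of that environment (Anaconda itself is essentially just a package and environment manager) and create a new xx.pth file, for example tensorflow_model.pth, containing the three paths of your models checkout, one per line: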

E:\models-master\research
E:\models-master\research\slim
E:\models-master\research\object_detection

If you installed Python on its own rather than through Anaconda, do the same thing in the Lib\site-packages directory of that Python installation.

That solves the problem.
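
This works because Python's site module reads every .pth file in site-packages at startup and appends each listed directory that exists to sys.path. As a minimal sketch, the hypothetical helper below writes such a file. The file name and the E:\ paths are the ones from this post; the target directory is passed in explicitly so the sketch can be tried in a scratch folder first.

```python
import os


def write_pth(site_packages_dir, paths, name="tensorflow_model.pth"):
    """Write one directory per line into <site_packages_dir>/<name>."""
    pth_path = os.path.join(site_packages_dir, name)
    with open(pth_path, "w", encoding="utf-8") as f:
        for p in paths:
            f.write(p + "\n")
    return pth_path


# The three directories listed above; adjust them to your own checkout.
MODEL_PATHS = [
    r"E:\models-master\research",
    r"E:\models-master\research\slim",
    r"E:\models-master\research\object_detection",
]
```

The real site-packages directory can be found with site.getsitepackages(). After restarting Python, import object_detection should succeed if the paths are correct.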

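2. Cannot import string_int_label_map_pb2

The exact error:

tensorflow object-detection ImportError: cannot import name 'string_int_label_map_pb2'

Solution:
1. Download protoc-3.6.1-win32.
2. After extracting it, add the directory that contains bin\protoc.exe to the PATH environment variable.
3. Open cmd and run the following command in the /model/research/ directory: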

protoc object_detection/protos/*.proto --python_out=.

It fails with:

object_detection/protos/*.proto: No such file or directory

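This is because the *.proto wildcard is not expanded by the Windows cmd shell, so protoc receives it literally. Use Git Bash instead of CMD. This requires Git to be installed on your Windows system (download Git for Windows). Once it is installed, retry the command from Git Bash in the /model/research/ directory: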

protoc object_detection/protos/*.proto --python_out=. 

Problem solved.
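
A shell-independent alternative, sketched here as an assumption rather than the approach used in this post, is to expand the wildcard in Python and pass explicit file names to protoc. The helper names build_protoc_command and compile_protos are hypothetical.

```python
import glob
import os
import subprocess


def build_protoc_command(research_dir):
    """Build a protoc command with the *.proto wildcard already expanded."""
    pattern = os.path.join(research_dir, "object_detection", "protos", "*.proto")
    proto_files = sorted(glob.glob(pattern))
    if not proto_files:
        raise FileNotFoundError("no .proto files match " + pattern)
    # Use paths relative to research_dir, matching the command shown above.
    rel_files = [
        os.path.relpath(p, research_dir).replace(os.sep, "/") for p in proto_files
    ]
    return ["protoc"] + rel_files + ["--python_out=."]


def compile_protos(research_dir):
    """Run protoc from research_dir; requires protoc to be on PATH."""
    subprocess.run(build_protoc_command(research_dir), cwd=research_dir, check=True)
```

Calling compile_protos on the research directory of your checkout should then behave the same from cmd, PowerShell, or Git Bash, because no shell expansion is involved.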

3. tensorflow has no attribute GraphDef

The exact error:

Error: module 'tensorflow' has no attribute 'GraphDef'

This is caused by the API differences between TensorFlow 1.x and 2.x.

Solution:

# Original statement
graph_def = tf.GraphDef.FromString(file_handle.read())

Change it to:

# Modified statement
graph_def = tf.compat.v1.GraphDef.FromString(file_handle.read())
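
If the same script has to run under both TensorFlow 1.x and 2.x, one option is to look the class up at whichever location exists. This is a minimal sketch with a hypothetical helper, resolve_attr, demonstrated on stand-in objects so that it runs without TensorFlow installed.

```python
from types import SimpleNamespace


def resolve_attr(module, *dotted_names):
    """Return the first dotted attribute path that exists on module."""
    for dotted in dotted_names:
        obj = module
        try:
            for part in dotted.split("."):
                obj = getattr(obj, part)
        except AttributeError:
            continue  # this location does not exist, try the next one
        return obj
    raise AttributeError("none of %r found" % (dotted_names,))


# Stand-ins for the two API layouts (assumptions, not real TensorFlow modules).
fake_tf1 = SimpleNamespace(GraphDef="GraphDef class")
fake_tf2 = SimpleNamespace(
    compat=SimpleNamespace(v1=SimpleNamespace(GraphDef="GraphDef class"))
)
```

With real TensorFlow the call would look like GraphDef = resolve_attr(tf, 'compat.v1.GraphDef', 'GraphDef').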

4. tensorflow has no attribute gfile

The exact error:

AttributeError: module 'tensorflow' has no attribute 'gfile'

Running the following TensorFlow 1.x code

if not tf.gfile.Exists(DATA_DIRECTORY):
    tf.gfile.MakeDirs(DATA_DIRECTORY)
with tf.gfile.GFile(filepath) as f:
    data = f.read()

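produces the following error:

AttributeError: module 'tensorflow' has no attribute 'gfile'

This is because in the current version gfile is exposed under the io package as tf.io.gfile (it is defined in file_io.py).

Solution:

Just change the code to the following: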

if not tf.io.gfile.exists(DATA_DIRECTORY):
    tf.io.gfile.makedirs(DATA_DIRECTORY)
with tf.io.gfile.GFile(filepath) as f:
    data = f.read()
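
In TensorFlow 1.x these functions were spelled tf.gfile.Exists and tf.gfile.MakeDirs, so the rename changes both the module path and the function names. For a script with many such calls, the rename can be applied mechanically. The sketch below is a hypothetical helper with a deliberately partial mapping; any name not in the table is left untouched for manual review.

```python
import re

# Partial mapping of TF 1.x tf.gfile names to their tf.io.gfile equivalents.
GFILE_RENAMES = {
    "Exists": "exists",
    "MakeDirs": "makedirs",
    "GFile": "GFile",
}

_GFILE_PATTERN = re.compile(r"\btf\.gfile\.(\w+)")


def upgrade_gfile_calls(source):
    """Rewrite tf.gfile.X to tf.io.gfile.Y for the names in GFILE_RENAMES."""
    def replace(match):
        new_name = GFILE_RENAMES.get(match.group(1))
        # Unknown names are returned unchanged so they can be checked by hand.
        return "tf.io.gfile." + new_name if new_name else match.group(0)

    return _GFILE_PATTERN.sub(replace, source)
```

TensorFlow also ships a tf_upgrade_v2 command-line tool that performs this kind of rename across whole files, which is likely the better choice for anything larger than a small script.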

5. Matplotlib cannot show its UI window

The exact error:

UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.

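This is because Agg, the backend Matplotlib is currently using, has no graphical display. Commonly used backends with a GUI include TkAgg, so we switch to TkAgg.
Solution:

Add one statement to the program: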

import matplotlib
matplotlib.use('TkAgg')  # add this statement
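
TkAgg needs Python's tkinter module, which is not present in every Python installation. As a minimal sketch, assuming you would rather fall back to Agg than fail, the hypothetical helpers below pick a backend name based on whether tkinter can be imported.

```python
def tk_available():
    """Return True if the tkinter module can be imported."""
    try:
        import tkinter  # noqa: F401
    except ImportError:
        return False
    return True


def choose_backend(has_tk):
    """Pick the GUI backend when Tk is usable, otherwise the non-GUI Agg."""
    return "TkAgg" if has_tk else "Agg"


# Usage (requires matplotlib; call before creating any figure):
#   import matplotlib
#   matplotlib.use(choose_backend(tk_available()))
```

With the Agg fallback the figure still cannot be shown on screen, but it can be written to a file with savefig instead.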

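6. ERROR: invalid continuation byte

The exact error:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xca in position 0: invalid continuation byte

My cause was different from the causes usually given online for this error. The resource paths in my project differ from those in the post I followed, and I had not updated the file path in PATH_TO_LABELS, so the referenced text file did not exist, which produced the error.

Solution:
Change it to the correct path.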
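
Because a wrong path surfaced here as a misleading decoding error, it can help to check the paths before any model loading starts. This is a minimal sketch with hypothetical helper names, not part of the original script.

```python
import os


def missing_paths(named_paths):
    """Return {name: path} for every entry whose file does not exist."""
    return {
        name: path
        for name, path in named_paths.items()
        if not os.path.isfile(path)
    }


def require_files(named_paths):
    """Raise a clear FileNotFoundError that names each missing file."""
    missing = missing_paths(named_paths)
    if missing:
        details = ", ".join("%s=%r" % item for item in sorted(missing.items()))
        raise FileNotFoundError("missing files: " + details)
```

In the script above it could be called right after the path constants are defined, as require_files({'PATH_TO_CKPT': PATH_TO_CKPT, 'PATH_TO_LABELS': PATH_TO_LABELS}).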

Finally

All the references I consulted are listed below:

posted @ 2021-01-24 23:46  DAmarkday