tensorflow object detection预训练模型踩坑指南

记录使用tensorflow的物体检测预处理模型时所遇到的各种坑

python文件我使用的博客地址为:
Tensorflow object detection API训练自己的目标检测模型(检测图片中和视频中的物体)

由于原博主的tensorflow版本是tensorflow1.14.0版本,为了适配2.0版本人做了部分修改.
代码如下:

import numpy as np import os import six.moves.urllib as urllib import sys import tarfile import tensorflow as tf import zipfile from collections import defaultdict from io import StringIO # from matplotlib import pyplot as plt import matplotlib from PIL import Image # # This is needed to display the images. # %matplotlib inline # This is needed since the notebook is stored in the object_detection folder. sys.path.append("..") # from utils import label_map_util # from utils import visualization_utils as vis_util from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util # What model to download. MODEL_NAME = 'other_library_files/ssd_mobilenet_v1_coco_2017_11_17' # MODEL_FILE = MODEL_NAME + '.tar.gz' # DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' # Path to frozen detection graph. This is the actual model that is used for the object detection. PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' # List of the strings that is used to add correct label for each box. PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') PATH_TO_LABELS = 'other_library_files/models-master/research/object_detection/data/mscoco_label_map.pbtxt' NUM_CLASSES = 90 # download model # opener = urllib.request.URLopener() # 下载模型,如果已经下载好了下面这句代码可以注释掉 # opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) # tar_file = tarfile.open(MODEL_FILE) # for file in tar_file.getmembers(): # file_name = os.path.basename(file.name) # if 'frozen_inference_graph.pb' in file_name: # tar_file.extract(file, os.getcwd()) # Load a (frozen) Tensorflow model into memory. detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.compat.v1.GraphDef() with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') # # Loading label map label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) category_index = label_map_util.create_category_index(categories) # Helper code def load_image_into_numpy_array(image): (im_width, im_height) = image.size return np.array(image.getdata()).reshape( (im_height, im_width, 3)).astype(np.uint8) # For the sake of simplicity we will use only 2 images: # image1.jpg # image2.jpg # If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS. PATH_TO_TEST_IMAGES_DIR = 'other_library_files/models-master/research/object_detection/test_images' # 测试图片文件夹 TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 6)] # 遍历测试图片文件夹的图片 # Size, in inches, of the output images. IMAGE_SIZE = (12, 8) # 图像识别 with detection_graph.as_default(): with tf.compat.v1.Session(graph=detection_graph) as sess: # Definite input and output Tensors for detection_graph image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') # Each box represents a part of the image where a particular object was detected. detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') # Each score represent how level of confidence for each of the objects. # Score is shown on the result image, together with the class label. detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') # 图片依次识别 for image_path in TEST_IMAGE_PATHS: image = Image.open(image_path) # the array based representation of the image will be used later in order to prepare the # result image with boxes and labels on it. # 稍后将使用基于数组的图像表示形式来准备带有框和标签的结果图像。 image_np = load_image_into_numpy_array(image) # Expand dimensions since the model expects images to have shape: [1, None, None, 3] # 由于模型期望图像具有形状,因此请扩展尺寸:[1 ,无,无,3] image_np_expanded = np.expand_dims(image_np, axis=0) image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') # Each box represents a part of the image where a particular object was detected. # 每个框代表检测到特定对象的图像的一部分。 boxes = detection_graph.get_tensor_by_name('detection_boxes:0') # Each score represent how level of confidence for each of the objects. # Score is shown on the result image, together with the class label. scores = detection_graph.get_tensor_by_name('detection_scores:0') # 识别正确分数 classes = detection_graph.get_tensor_by_name('detection_classes:0') # 识别种类 num_detections = detection_graph.get_tensor_by_name('num_detections:0') # 识别个数 # Actual detection. (boxes, scores, classes, num_detections) = sess.run( [boxes, scores, classes, num_detections], feed_dict={image_tensor: image_np_expanded}) # Visualization of the results of a detection. # 可视化检测结果。对图片进行画框 vis_util.visualize_boxes_and_labels_on_image_array( image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, # 使用显示标准化坐标 line_thickness=8) # line_thickness 画线框的线条粗细 # 用于显示 matplotlib.use('TkAgg') matplotlib.pyplot.figure(figsize=IMAGE_SIZE) matplotlib.pyplot.imshow(image_np) matplotlib.pyplot.show()

所遇到的问题:

1.缺少object_dection模块

ModuleNotFoundError: No module named ‘object_detection’

在我们使用tensorflow的预处理模型时需要用到tensorflow的api,它包含了很多的工具方法,包括图像分类、检测、自然语言处理NLP、视频预测、图像理解等等,我们需要的对象检测API也包括在这里面。直接在github下载zip包,然后如果你的tensorflow是在Anaconda的一个开发环境中,就直接在anaconda的Lib的site-packages目录下(Anaconda本身只是一个包管理器)新建一个xx.pth的文件,比如新建一个tensorflow_model.pth文件,添加你的models的3个路径:

E:\models-master\research
E:\models-master\research\slim
E:\models-master\research\object_detection

如果你安装时是单独安装的python没有依靠anaconda,只需要在python的Lib的site-packages目录下按照上面的一样操作即可。

问题就解决了

2.无法导入string_int_label_map_pb2

具体错误为:

tensorflow object-detection ImportError: cannot import name 'string_int_label_map_pb2'

解决方法:
1.下载protoc-3.6.1-win32
2.解压后将bin里面的protoc.exe的路径加到电脑的环境变量的PATH中
3.打开cmd,在/model/research/目录下执行命令

protoc object_detection/protos/*.proto --python_out=.

发现出错

object_detection/protos/*.proto: No such file or directory

这是因为*.”在windows系统无法识别。这时就可以使用git命令,不要用CMD命令,当然这需要你Windows系统安装了git了,Git for Windows下载 安装完后,在/model/research/目录下使用git命令重试

protoc object_detection/protos/*.proto --python_out=.

问题解决

3.tensorflow缺失GraphDef属性

具体错误为:

Error: module 'tensorflow' has no attribute 'GraphDef'

这是因为tensorflow的1和2的版本的差异

解决办法:

# 原来语句 graph_def = tf.GraphDef.FromString(file_handle.read())

修改为

# 修改后语句 graph_def = tf.compat.v1.GraphDef.FromString(file_handle.read())

4.tensorflow缺失gfile属性

具体错误为:

AttributeError: module 'tensorflow' has no attribute 'gfile'

运行如下代码

if not tf.gfile.exists(DATA_DIRECTORY): tf.gfile.makedirs(DATA_DIRECTORY) with tf.gfile.GFile(filepath) as f:

会出现如下问题:

AttributeError: module ‘tensorflow’ has no attribute ‘gfile’

这是因为在当前的版本中,gfile已经定义在io包的file_io.py中。

解决办法:

所以只要改为下面的即可:

if not tf.io.gfile.exists(DATA_DIRECTORY): tf.io.gfile.makedirs(DATA_DIRECTORY) with tf.io.gfile.GFile(filepath) as f:

5.tensorflow无法显示Matplotlib的ui界面

具体错误为:

UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.

这是因为Matplotlib默认中使用的Agg是一个没有图形显示界面的终端,常用的有图形界面显示的终端有TkAgg等。,所以我们选择更换为TKAgg
解决办法:

在程序中增加一条语句:

import matplotlib matplotlib.use('TkAgg') (增加这条语句)

6.ERROR:无效的连续字节

具体错误为:

UnicodeDecodeError: ‘utf-8’ codec can’t decode byte 0xca in position 0: invalid continuation byte

博主的原因与网上其他造成这种错误的原因不同,博主是因为博主项目中的资源路径和所参考的文献中的资源路径不同,代码中的PATH_TO_LABELS的文本地址没有修改导致引用的文本不存在从而产生的错误

解决办法:
修改为正确的路径即可。

最后

所参考的所有文献地址如下:


__EOF__

本文作者damarkday知识库
本文链接https://www.cnblogs.com/GoodMemoryBlog/p/14323015.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角推荐一下。您的鼓励是博主的最大动力!
posted @   DAmarkday  阅读(2081)  评论(0编辑  收藏  举报
编辑推荐:
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
阅读排行:
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)
点击右上角即可分享
微信分享提示