TensorFlow Object Detection

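(1) Install Protobuf
TensorFlow uses Protocol Buffers internally, and the Object Detection API needs the protoc compiler, so it has to be installed separately.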

Shell code:
    # yum info protobuf protobuf-compiler
    2.5.0   <- too old; protobuf 2.6.1 or later is required
    # yum -y install autoconf automake libtool curl make gcc-c++ unzip
    # cd /usr/local/src/
    # wget https://github.com/google/protobuf/archive/v3.3.1.tar.gz -O protobuf-3.3.1.tar.gz
    # tar -zxvf protobuf-3.3.1.tar.gz
    # cd protobuf-3.3.1
    # ./autogen.sh
    # ./configure --prefix=/usr/local/protobuf
    # make
    # make install
    # ldconfig
    # export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/protobuf/lib
    # export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/protobuf/lib
    # export PATH=$PATH:/usr/local/protobuf/bin
    # protoc --version
    libprotoc 3.3.1
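The version requirement above (protoc 2.6.1 or later) can also be checked from Python by parsing the output of `protoc --version`. This is a small sketch. The function names are hypothetical, and the parser assumes the `libprotoc X.Y.Z` output format shown above.

```python
import re

# Minimum protoc version required according to the yum note above.
MIN_PROTOC = (2, 6, 1)


def parse_protoc_version(output):
    """Parse text such as 'libprotoc 3.3.1' into a tuple of ints."""
    match = re.search(r"libprotoc\s+(\d+)\.(\d+)\.(\d+)", output)
    if match is None:
        raise ValueError("unrecognized protoc output: %r" % output)
    return tuple(int(part) for part in match.groups())


def protoc_is_new_enough(output, minimum=MIN_PROTOC):
    """True if the version in `output` is at least `minimum` (tuple comparison)."""
    return parse_protoc_version(output) >= minimum
```

For example, pass it the decoded result of `subprocess.check_output(["protoc", "--version"])`. Tuple comparison orders versions field by field, so 2.5.0 is rejected and 3.3.1 is accepted.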



(2) Configure the TensorFlow Object Detection API

Shell code:
    # source /usr/local/tensorflow2/bin/activate
    # cd /usr/local/tensorflow2/tensorflow-models



Install the dependencies

Shell code:
    # pip install pillow
    # pip install lxml
    # pip install jupyter
    # pip install matplotlib



Compile the protobuf files

Shell code:
    # protoc object_detection/protos/*.proto --python_out=.



Set the environment variables

Shell code:
    # export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
    # ldconfig
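The export only affects processes started from that shell. If Jupyter or a script is launched from somewhere the variable was not set, the same effect can be obtained inside Python. This is a minimal sketch. `add_model_paths` is a hypothetical helper, and `models_root` stands for the tensorflow-models checkout, for example /usr/local/tensorflow2/tensorflow-models.

```python
import os
import sys


def add_model_paths(models_root):
    """Append models_root and models_root/slim to sys.path, mirroring
    `export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim` for this interpreter only."""
    for path in (models_root, os.path.join(models_root, "slim")):
        if path not in sys.path:  # avoid duplicates on repeated calls
            sys.path.append(path)
    return sys.path
```

Call it before `from object_detection ...` or `from utils ...` imports. Because it only edits `sys.path`, it does not change the environment of child processes.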



Test

Shell code:
    # python object_detection/builders/model_builder_test.py


If the output ends with OK, the setup is complete.

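(3) Run the demo from the documentation
Use a pre-trained model to detect the objects in an image. The official documentation provides a Jupyter-based tutorial.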

Shell code:
    # source /usr/local/tensorflow2/bin/activate
    # cd /usr/local/tensorflow2/tensorflow-models/object_detection/
    # jupyter notebook --generate-config --allow-root
    # python -c 'from notebook.auth import passwd;print(passwd())'
    Enter password:123456
    Verify password:123456
    sha1:7d026454901a:009ae34a09296674d4a13521b80b8527999fd3da
    # vi /root/.jupyter/jupyter_notebook_config.py
    c.NotebookApp.password = 'sha1:7d026454901a:009ae34a09296674d4a13521b80b8527999fd3da'
    # jupyter notebook --ip=127.0.0.1 --allow-root
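The string pasted into `c.NotebookApp.password` has the form `sha1:<salt>:<digest>`. The sketch below illustrates that format. It assumes, as classic notebook releases appear to do, that the digest is SHA-1 over the password followed by the salt. The helper names are hypothetical. This is meant for understanding the format and checking a password against an existing hash, not as a replacement for `notebook.auth.passwd`.

```python
import hashlib
import secrets


def make_notebook_hash(password, salt=None):
    """Build 'sha1:<salt>:<hexdigest>' with digest = sha1(password + salt) (assumed format)."""
    if salt is None:
        salt = secrets.token_hex(6)  # 12 hex characters, like 7d026454901a above
    digest = hashlib.sha1((password + salt).encode("utf-8")).hexdigest()
    return "sha1:%s:%s" % (salt, digest)


def check_notebook_hash(hashed, password):
    """Recompute the digest with the stored algorithm and salt, then compare."""
    algorithm, salt, digest = hashed.split(":", 2)
    h = hashlib.new(algorithm)
    h.update((password + salt).encode("utf-8"))
    return secrets.compare_digest(h.hexdigest(), digest)
```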



Open http://127.0.0.1:8888/ in a browser and open object_detection_tutorial.ipynb:
http://127.0.0.1:8888/notebooks/object_detection_tutorial.ipynb


By default the notebook processes image1.jpg and image2.jpg in the object_detection/test_images folder. To run detection on other images, change the code in the second-to-last cell:

Python code:
    # TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
    TEST_IMAGE_PATHS = ['<your image path>']
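To process every JPEG in a folder instead of listing paths by hand, the list can be built with `glob`. This is a minimal sketch. `list_test_images` is a hypothetical helper, and the default `*.jpg` pattern is an assumption (it is case-sensitive on Linux, so `.JPG` files need a different pattern).

```python
import glob
import os


def list_test_images(image_dir, pattern="*.jpg"):
    """Return the sorted paths of the files in image_dir that match pattern."""
    return sorted(glob.glob(os.path.join(image_dir, pattern)))
```

In the notebook cell this becomes `TEST_IMAGE_PATHS = list_test_images(PATH_TO_TEST_IMAGES_DIR)`. Sorting keeps the processing order stable from run to run.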



In the last cell, change these two lines:

Python code:
    plt.figure(figsize=IMAGE_SIZE)
    plt.imshow(image_np)

to:

Python code:
    print(os.path.splitext(image_path)[0] + '_labeled.jpg') # Add
    plt.figure(figsize=IMAGE_SIZE, dpi=300) # Modify
    plt.imshow(image_np)
    plt.savefig(os.path.splitext(image_path)[0] + '_labeled.jpg') # Add
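Building the output name with `image_path.split('.')[0]` breaks as soon as the path contains another dot, for example a relative path starting with `./`. `os.path.splitext` strips only the final extension. This is a small sketch. `labeled_output_path` is a hypothetical helper name.

```python
import os


def labeled_output_path(image_path, suffix="_labeled.jpg"):
    """Return the path for the annotated copy, e.g. a/b.jpg -> a/b_labeled.jpg."""
    return os.path.splitext(image_path)[0] + suffix
```

With `./test_images/image1.jpg`, `split('.')[0]` yields an empty string, so every image would be saved as `_labeled.jpg` in the working directory. `labeled_output_path` yields `./test_images/image1_labeled.jpg`.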




Then run every cell in order from top to bottom and wait for the results. (The Download Model cell fetches the model from the internet and can be slow.)


When it finishes, the labeled result images appear in the object_detection/test_images folder:
image1_labeled.jpg
image2_labeled.jpg



Comparing these with the result images in the official tutorial shows that the output depends considerably on the machine.



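(4) Detect objects from a script

Take an image from ImageNet, 2008_004037.jpg, as the test image, and turn the code in object_detection_tutorial.ipynb into a standalone script. The script uses relative paths (utils, data, test_images), so run it from the object_detection directory: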

Shell code:
    # vi object_detect_demo.py
    # python object_detect_demo.py



Python code:
    import numpy as np
    import os
    import six.moves.urllib as urllib
    import tarfile
    import tensorflow as tf
    import matplotlib

    # Matplotlib picks an X Window backend by default; use the non-interactive
    # Agg backend so the script also runs on a machine without a display.
    matplotlib.use('Agg')

    from matplotlib import pyplot as plt
    from PIL import Image
    from utils import label_map_util
    from utils import visualization_utils as vis_util

    ##################### Download Model
    # What model to download.
    MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
    MODEL_FILE = MODEL_NAME + '.tar.gz'
    DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

    # Path to the frozen detection graph. This is the actual model used for object detection.
    PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

    # Label map: the strings used to label each box.
    PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

    NUM_CLASSES = 90

    # Download the model if it is not already downloaded.
    if not os.path.exists(PATH_TO_CKPT):
        print('Downloading model... (This may take over 5 minutes)')
        opener = urllib.request.URLopener()
        opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
        print('Extracting...')
        tar_file = tarfile.open(MODEL_FILE)
        for file in tar_file.getmembers():
            file_name = os.path.basename(file.name)
            if 'frozen_inference_graph.pb' in file_name:
                tar_file.extract(file, os.getcwd())
    else:
        print('Model already downloaded.')

    ##################### Load a (frozen) TensorFlow model into memory.
    print('Loading model...')
    detection_graph = tf.Graph()

    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    ##################### Loading label map
    print('Loading label map...')
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    ##################### Helper code
    def load_image_into_numpy_array(image):
      # Convert to RGB first so grayscale or RGBA images also reshape to 3 channels.
      image = image.convert('RGB')
      (im_width, im_height) = image.size
      return np.array(image.getdata()).reshape(
          (im_height, im_width, 3)).astype(np.uint8)

    ##################### Detection
    # Path to the test image
    TEST_IMAGE_PATH = 'test_images/2008_004037.jpg'
    # The labeled copy is saved next to it: test_images/2008_004037_labeled.jpg
    OUTPUT_IMAGE_PATH = os.path.splitext(TEST_IMAGE_PATH)[0] + '_labeled.jpg'

    # Size, in inches, of the output images.
    IMAGE_SIZE = (12, 8)

    print('Detecting...')
    with detection_graph.as_default():
      with tf.Session(graph=detection_graph) as sess:
        print(TEST_IMAGE_PATH)
        image = Image.open(TEST_IMAGE_PATH)
        image_np = load_image_into_numpy_array(image)
        # The model expects a batch: shape [1, height, width, 3].
        image_np_expanded = np.expand_dims(image_np, axis=0)
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Print the results of a detection.
        print(scores)
        print(classes)
        print(category_index)
        # Visualization of the results of a detection (draws onto image_np in place).
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)
        print(OUTPUT_IMAGE_PATH)
        plt.figure(figsize=IMAGE_SIZE, dpi=300)
        plt.imshow(image_np)
        plt.savefig(OUTPUT_IMAGE_PATH)



The detection results (scores, classes and category_index) are:

[[ 0.91731095  0.80875194  0.67557526  0.67192227  0.3568708   0.23992854
   0.21897335  0.21443138  0.17383011  0.15901341  0.15674619  0.1558814
   0.15265906  0.1489363   0.14805503  0.13470834  0.132047    0.12655555
   0.12086334  0.11752894  0.10897312  0.10791111  0.10386674  0.10181901
   0.09687284  0.09644313  0.0929096   0.09187065  0.08420605  0.08250966
   0.08131051  0.07928694  0.07632151  0.07570603  0.0749495   0.07267584
   0.07258119  0.07075463  0.06964011  0.06901822  0.06894562  0.06892171
   0.06805679  0.06769397  0.06536105  0.06501643  0.06417865  0.06416738
   0.06377003  0.0634084   0.06247949  0.06245064  0.06173467  0.06126672
   0.06037482  0.05930964  0.05813492  0.05751488  0.05747007  0.05746768
   0.05737954  0.05694786  0.05581251  0.05559204  0.05539726  0.054422
   0.05410738  0.05389332  0.05359224  0.05349119  0.05328105  0.05284562
   0.0527565   0.05231072  0.05224103  0.05190464  0.05123441  0.05110639
   0.05002856  0.04982324  0.04956287  0.04943769  0.04906119  0.04891028
   0.04835404  0.04812568  0.0470486   0.04596276  0.04592303  0.04565331
   0.04564101  0.04550403  0.04531116  0.04507401  0.04495776  0.04489629
   0.04475424  0.0447024   0.04434219  0.04395287]]
[[  1.   1.  44.  44.  44.  44.  44.  75.  44.  44.  44.  82.  44.  88.
   79.  44.  44.  44.  88.  44.  88.  79.  44.  82.   1.  47.  88.  67.
   44.  70.  47.  79.  67.  67.  67.  67.  79.  72.  47.   1.  44.  44.
   44.   1.  67.  75.  72.  62.   1.   1.  44.  82.  79.  47.  79.  67.
   44.   1.  51.  75.  79.  51.  79.  62.  67.  44.  82.  82.  79.  82.
   79.  75.  72.  82.   1.   1.  46.  88.  82.  82.  82.  44.  67.  62.
   82.  79.  62.   1.  67.   1.  82.   1.  67.   1.  44.  88.  79.  51.
   44.  82.]]
{1: {'id': 1, 'name': u'person'}, 2: {'id': 2, 'name': u'bicycle'}, 3: {'id': 3, 'name': u'car'}, 4: {'id': 4, 'name': u'motorcycle'}, 5: {'id': 5, 'name': u'airplane'}, 6: {'id': 6, 'name': u'bus'}, 7: {'id': 7, 'name': u'train'}, 8: {'id': 8, 'name': u'truck'}, 9: {'id': 9, 'name': u'boat'}, 10: {'id': 10, 'name': u'traffic light'}, 11: {'id': 11, 'name': u'fire hydrant'}, 13: {'id': 13, 'name': u'stop sign'}, 14: {'id': 14, 'name': u'parking meter'}, 15: {'id': 15, 'name': u'bench'}, 16: {'id': 16, 'name': u'bird'}, 17: {'id': 17, 'name': u'cat'}, 18: {'id': 18, 'name': u'dog'}, 19: {'id': 19, 'name': u'horse'}, 20: {'id': 20, 'name': u'sheep'}, 21: {'id': 21, 'name': u'cow'}, 22: {'id': 22, 'name': u'elephant'}, 23: {'id': 23, 'name': u'bear'}, 24: {'id': 24, 'name': u'zebra'}, 25: {'id': 25, 'name': u'giraffe'}, 27: {'id': 27, 'name': u'backpack'}, 28: {'id': 28, 'name': u'umbrella'}, 31: {'id': 31, 'name': u'handbag'}, 32: {'id': 32, 'name': u'tie'}, 33: {'id': 33, 'name': u'suitcase'}, 34: {'id': 34, 'name': u'frisbee'}, 35: {'id': 35, 'name': u'skis'}, 36: {'id': 36, 'name': u'snowboard'}, 37: {'id': 37, 'name': u'sports ball'}, 38: {'id': 38, 'name': u'kite'}, 39: {'id': 39, 'name': u'baseball bat'}, 40: {'id': 40, 'name': u'baseball glove'}, 41: {'id': 41, 'name': u'skateboard'}, 42: {'id': 42, 'name': u'surfboard'}, 43: {'id': 43, 'name': u'tennis racket'}, 44: {'id': 44, 'name': u'bottle'}, 46: {'id': 46, 'name': u'wine glass'}, 47: {'id': 47, 'name': u'cup'}, 48: {'id': 48, 'name': u'fork'}, 49: {'id': 49, 'name': u'knife'}, 50: {'id': 50, 'name': u'spoon'}, 51: {'id': 51, 'name': u'bowl'}, 52: {'id': 52, 'name': u'banana'}, 53: {'id': 53, 'name': u'apple'}, 54: {'id': 54, 'name': u'sandwich'}, 55: {'id': 55, 'name': u'orange'}, 56: {'id': 56, 'name': u'broccoli'}, 57: {'id': 57, 'name': u'carrot'}, 58: {'id': 58, 'name': u'hot dog'}, 59: {'id': 59, 'name': u'pizza'}, 60: {'id': 60, 'name': u'donut'}, 61: {'id': 61, 'name': u'cake'}, 62: {'id': 62, 
'name': u'chair'}, 63: {'id': 63, 'name': u'couch'}, 64: {'id': 64, 'name': u'potted plant'}, 65: {'id': 65, 'name': u'bed'}, 67: {'id': 67, 'name': u'dining table'}, 70: {'id': 70, 'name': u'toilet'}, 72: {'id': 72, 'name': u'tv'}, 73: {'id': 73, 'name': u'laptop'}, 74: {'id': 74, 'name': u'mouse'}, 75: {'id': 75, 'name': u'remote'}, 76: {'id': 76, 'name': u'keyboard'}, 77: {'id': 77, 'name': u'cell phone'}, 78: {'id': 78, 'name': u'microwave'}, 79: {'id': 79, 'name': u'oven'}, 80: {'id': 80, 'name': u'toaster'}, 81: {'id': 81, 'name': u'sink'}, 82: {'id': 82, 'name': u'refrigerator'}, 84: {'id': 84, 'name': u'book'}, 85: {'id': 85, 'name': u'clock'}, 86: {'id': 86, 'name': u'vase'}, 87: {'id': 87, 'name': u'scissors'}, 88: {'id': 88, 'name': u'teddy bear'}, 89: {'id': 89, 'name': u'hair drier'}, 90: {'id': 90, 'name': u'toothbrush'}}



The top four objects with scores above 50% are:

scores - 0.91731095  0.80875194  0.67557526  0.67192227
classes - 1.   1.  44.  44.
category_index - 1: {'id': 1, 'name': u'person'} 44: {'id': 44, 'name': u'bottle'}


The labeled image also marks these four objects: 91% person, 80% person, 67% bottle and 67% bottle.
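Picking out the detections above 50% takes only a few lines of NumPy. This is a minimal sketch. `top_detections` and `format_labels` are hypothetical helpers, and `scores`, `classes` and `category_index` are the values printed above. The percentages are truncated with `int()`, which matches the 91% and 80% in the labeled image (rounding would give 92% and 81%).

```python
import numpy as np


def top_detections(scores, classes, category_index, min_score=0.5):
    """Return (name, score) pairs for every detection scoring above min_score.

    scores and classes are the (1, N) arrays returned by sess.run;
    category_index maps class ids to {'id': ..., 'name': ...} dicts.
    """
    scores = np.asarray(scores).ravel()
    classes = np.asarray(classes).ravel().astype(np.int32)
    results = []
    for score, class_id in zip(scores, classes):
        if score > min_score:
            # Unknown ids fall back to 'N/A' instead of raising KeyError.
            name = category_index.get(int(class_id), {}).get('name', 'N/A')
            results.append((name, float(score)))
    return results


def format_labels(results):
    """Format results as 'NN% name', truncating the percentage with int()."""
    return ['{}% {}'.format(int(100 * score), name) for name, score in results]
```

After `sess.run`, `format_labels(top_detections(scores, classes, category_index))` lists the same four objects that are boxed in the image.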


(5) Run training locally

1) Generate TFRecords
Convert the JPEG image data into TFRecord format.

Shell code:
    # cd /usr/local/tensorflow2/tensorflow-models/object_detection
    # wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
    # wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
    # tar -zxvf annotations.tar.gz
    # tar -zxvf images.tar.gz
    # python create_pet_tf_record.py --data_dir=`pwd` --output_dir=`pwd`


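The images folder contains the already-annotated JPEG images, and the annotations themselves are in the annotations folder. When the script finishes, it creates two files in the current directory: pet_train.record and pet_val.record.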
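Before the boxes are written to the TFRecords, their corners are expressed relative to the image size, so every coordinate lies in [0, 1]. The sketch below shows that normalization for a PASCAL VOC-style XML annotation, the format used for the pets annotations. It uses the standard library's `xml.etree` rather than the script's own parsing code. `normalized_boxes` is a hypothetical name. The (ymin, xmin, ymax, xmax) order follows the layout of `detection_boxes`.

```python
import xml.etree.ElementTree as ET


def normalized_boxes(xml_text):
    """Return (name, ymin, xmin, ymax, xmax) per object, divided by the image size."""
    root = ET.fromstring(xml_text)
    width = float(root.find("size/width").text)
    height = float(root.find("size/height").text)
    boxes = []
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        boxes.append((
            obj.find("name").text,
            float(bndbox.find("ymin").text) / height,
            float(bndbox.find("xmin").text) / width,
            float(bndbox.find("ymax").text) / height,
            float(bndbox.find("xmax").text) / width,
        ))
    return boxes
```

Normalized coordinates do not depend on how the image is resized later, which is why the boxes in the detection output are also in [0, 1] and need `use_normalized_coordinates=True` when drawn.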

2) Configure the pipeline
object_detection/samples/configs contains pipeline configs for the various models. Copy one to use as a starting point.

Shell code:
    # wget http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_11_06_2017.tar.gz
    # tar -zxvf faster_rcnn_resnet101_coco_11_06_2017.tar.gz
    # cp samples/configs/faster_rcnn_resnet101_pets.config mypet.config
    # vi mypet.config


Replace the PATH_TO_BE_CONFIGURED placeholders as follows:

fine_tune_checkpoint: "/usr/local/tensorflow2/tensorflow-models/object_detection/faster_rcnn_resnet101_coco_11_06_2017/model.ckpt"
from_detection_checkpoint: true

train_input_reader: {
  tf_record_input_reader {
    input_path: "/usr/local/tensorflow2/tensorflow-models/object_detection/pet_train.record"
  }
  label_map_path: "/usr/local/tensorflow2/tensorflow-models/object_detection/data/pet_label_map.pbtxt"
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "/usr/local/tensorflow2/tensorflow-models/object_detection/pet_val.record"
  }
  label_map_path: "/usr/local/tensorflow2/tensorflow-models/object_detection/data/pet_label_map.pbtxt"
}


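Set from_detection_checkpoint to true and point fine_tune_checkpoint at the checkpoint path. Starting from a checkpoint that someone else has already trained reduces training time.
Download links for the checkpoints are listed at: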
https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md
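Instead of editing the placeholders by hand, they can be filled in with a short script. This sketch assumes each placeholder has the form `PATH_TO_BE_CONFIGURED/<file name>` inside a double-quoted string, and that you supply a mapping from each file name to its full path. `fill_placeholders` is a hypothetical helper.

```python
import re

PLACEHOLDER = "PATH_TO_BE_CONFIGURED"
# Matches PATH_TO_BE_CONFIGURED/<file name> up to the closing double quote.
_PATTERN = re.compile(re.escape(PLACEHOLDER) + r'/([^"]+)')


def fill_placeholders(config_text, paths):
    """Replace each PATH_TO_BE_CONFIGURED/<name> with paths[<name>]."""
    def replace(match):
        name = match.group(1)
        if name not in paths:
            raise KeyError("no path given for " + name)
        return paths[name]
    return _PATTERN.sub(replace, config_text)
```

To use it, read mypet.config, call `fill_placeholders` with entries such as `"model.ckpt"`, `"pet_train.record"`, `"pet_val.record"` and `"pet_label_map.pbtxt"`, and write the result back. A missing entry raises `KeyError` instead of leaving a placeholder in the file.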

3) Train and evaluate

Shell code:
    # mkdir -p /usr/local/tensorflow2/tensorflow-models/object_detection/model/train
    # mkdir -p /usr/local/tensorflow2/tensorflow-models/object_detection/model/eval



-- Training --

Shell code:
    # cd /usr/local/tensorflow2/tensorflow-models
    # python object_detection/train.py \
         --logtostderr \
         --pipeline_config_path='/usr/local/tensorflow2/tensorflow-models/object_detection/mypet.config' \
         --train_dir='/usr/local/tensorflow2/tensorflow-models/object_detection/model/train'


Output:
INFO:tensorflow:Starting Session.
INFO:tensorflow:Saving checkpoint to path /usr/local/tensorflow2/tensorflow-models/object_detection/model/train/model.ckpt
INFO:tensorflow:Starting Queues.
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Recording summary at step 0.



-- Evaluation --

Shell code:
    # cd /usr/local/tensorflow2/tensorflow-models
    # python object_detection/eval.py \
        --logtostderr \
        --pipeline_config_path='/usr/local/tensorflow2/tensorflow-models/object_detection/mypet.config' \
        --checkpoint_dir='/usr/local/tensorflow2/tensorflow-models/object_detection/model/train' \
        --eval_dir='/usr/local/tensorflow2/tensorflow-models/object_detection/model/eval'


The following files are created in the eval folder, one file per image:
events.out.tfevents.1499152949.localhost.localdomain
events.out.tfevents.1499152964.localhost.localdomain
events.out.tfevents.1499152980.localhost.localdomain

-- Viewing the results --

Shell code:
    # tensorboard --logdir=/usr/local/tensorflow2/tensorflow-models/object_detection/model/



*** train and eval keep running until you stop them
*** Training, evaluation and TensorBoard can run at the same time in three separate terminals

The tensorflow-models-master.zip downloaded before June 20 is not compatible with Python 3 and runs into many problems:
https://github.com/tensorflow/models/issues/1597
https://github.com/tensorflow/models/pull/1614/files
For example:

Traceback (most recent call last):
  File "create_pet_tf_record.py", line 213, in <module>
    tf.app.run()
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "create_pet_tf_record.py", line 208, in main
    image_dir, train_examples)
  File "create_pet_tf_record.py", line 177, in create_tf_record
    tf_example = dict_to_tf_example(data, label_map_dict, image_dir)
  File "create_pet_tf_record.py", line 131, in dict_to_tf_example
    'image/filename': dataset_util.bytes_feature(data['filename']),
  File "/usr/local/tensorflow/tensorflow-models/object_detection/utils/dataset_util.py", line 30, in bytes_feature
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
TypeError: 'leonberger_185.jpg' has type str, but expected one of: bytes



Traceback (most recent call last):
  File "object_detection/train.py", line 198, in <module>
    tf.app.run()
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "object_detection/train.py", line 194, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/usr/local/tensorflow/tensorflow-models/object_detection/trainer.py", line 184, in train
    data_augmentation_options)
  File "/usr/local/tensorflow/tensorflow-models/object_detection/trainer.py", line 77, in _create_input_queue
    prefetch_queue_capacity=prefetch_queue_capacity)
  File "/usr/local/tensorflow/tensorflow-models/object_detection/core/batcher.py", line 81, in __init__
    {key: tensor.get_shape() for key, tensor in tensor_dict.iteritems()})
AttributeError: 'dict' object has no attribute 'iteritems'



Traceback (most recent call last):
  File "object_detection/train.py", line 198, in <module>
    tf.app.run()
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "object_detection/train.py", line 194, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/usr/local/tensorflow/tensorflow-models/object_detection/trainer.py", line 184, in train
    data_augmentation_options)
  File "/usr/local/tensorflow/tensorflow-models/object_detection/trainer.py", line 77, in _create_input_queue
    prefetch_queue_capacity=prefetch_queue_capacity)
  File "/usr/local/tensorflow/tensorflow-models/object_detection/core/batcher.py", line 93, in __init__
    num_threads=num_batch_queue_threads)
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 919, in batch
    name=name)
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 697, in _batch
    tensor_list = _as_tensor_list(tensors)
  File "/usr/local/tensorflow/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 385, in _as_tensor_list
    return [tensors[k] for k in sorted(tensors)]
TypeError: '<' not supported between instances of 'tuple' and 'str'


And so on.
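The three tracebacks above come from Python 2 idioms. The first passes `str` where TensorFlow's `BytesList` expects `bytes`. The second calls `dict.iteritems()`, which Python 3 removed. The third sorts dict keys that mix tuples and strings, which raises `TypeError` in Python 3. The sketch below shows portable equivalents in plain Python. The helper names are hypothetical, and this is not the patch that was merged upstream (see the links above for that).

```python
def to_bytes(value, encoding="utf-8"):
    """Return value as bytes, encoding str first (BytesList wants bytes in Python 3)."""
    if isinstance(value, bytes):
        return value
    return value.encode(encoding)


def iter_items(mapping):
    """Portable stand-in for Python 2's mapping.iteritems()."""
    return iter(mapping.items())


def sorted_mixed_keys(mapping):
    """Sort keys that mix tuples and strings, which plain sorted() rejects in Python 3.

    Keys are ordered first by type name, then by value within each type.
    """
    return sorted(mapping, key=lambda key: (type(key).__name__, key))
```

For example, `to_bytes(data['filename'])` would avoid the `'leonberger_185.jpg' has type str` error. The sort key never compares a tuple with a string, because the type name differs and decides the order first.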

posted on 2018-03-06 11:57  jujua