TensorFlow Lite Model Maker --- Object Detection
Preface
I trained an object detector with TensorFlow Lite Model Maker so that it can ultimately run on MediaPipe. This post records my setup process and the problems I ran into along the way, as a note so I don't forget.
Environment Setup
The setup mainly follows Train a salad detector with TFLite Model Maker. One tip up front: install Python 3.6 if you can, as it avoids a lot of version conflicts.
Everything in this post runs inside a conda environment:
conda create -n mediapipe_train python==3.6
conda activate mediapipe_train
Install the required libraries:
pip install tflite-model-maker
pip install pycocotools
pip install tflite-support
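Before moving on, a quick sanity check that the three installs import cleanly (a minimal sketch; __version__ may not be exposed by every release):
import pycocotools  # imported only to confirm the install
import tflite_support  # imported only to confirm the install
import tflite_model_maker
print(tflite_model_maker.__version__)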
The code as given in Train a salad detector with TFLite Model Maker runs without major problems. When running locally, however, you have to download the model and the dataset yourself, and that is the main problem this post solves.
Code Walkthrough
1. Import the required packages:
import numpy as np
import os
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)
2. Load the model:
spec = model_spec.get('efficientdet_lite2')
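If efficientdet_lite2 is too heavy or too light for your target device, the same call accepts the other EfficientDet-Lite variants; a minimal sketch, assuming your tflite-model-maker release ships them:
# 'efficientdet_lite0' is the smallest/fastest, 'efficientdet_lite4' the most
# accurate; the available names depend on the installed version.
spec = model_spec.get('efficientdet_lite0')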
3. Load the dataset:
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv')
This runs fine on Colab, but to run locally you first have to download the dataset yourself, as follows.
Prerequisite: install gsutil. This is essential. (If it needs other dependencies, you can look them up yourself; there are no pitfalls, I simply don't remember them anymore.)
pip install gsutil
gsutil cp gs://cloud-ml-data/img/openimage/csv/salads_ml_use.csv ./
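For orientation, each row of the CSV uses the AutoML-style object detection format: split tag, image path, label, then normalized box coordinates, with empty columns for unused vertices. A row looks roughly like this (hedged reconstruction; check your downloaded copy):
TRAIN,gs://folder/image1.png,Salad,0.1,0.2,,,0.3,0.4,,
Column 1 (0-indexed) holds the image path, which is why the script below deduplicates on and rewrites exactly that column.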
salads_ml_use.csv is now on the local machine. Next, run the script below; the dataset.csv it writes is the file we will actually use:
import os
import pandas as pd

# Download every unique image referenced in column 1 of the CSV.
os.makedirs('imgs', exist_ok=True)
csv = pd.read_csv('./salads_ml_use.csv', header=None)
unique_rows = csv.drop_duplicates(subset=[1])
for i in range(len(unique_rows)):
    url = unique_rows.iat[i, 1]
    command_line = 'gsutil cp ' + str(url) + ' ./imgs'
    print(command_line)
    os.system(command_line)

# Re-read the full CSV (one image can appear in several rows, one per box)
# and rewrite the gs:// paths to point at the local copies.
csv = pd.read_csv('./salads_ml_use.csv', header=None)
for i in range(len(csv)):
    file_name = csv.iat[i, 1].split('/')[-1]
    csv.iloc[i, 1] = 'imgs/' + file_name
csv.to_csv('./dataset.csv', header=None, index=None)
Loading the dataset then becomes:
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('./dataset.csv')
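A quick way to confirm the three splits loaded as expected is to print their sizes (a sketch; DataLoader exposes a size property):
print('train:', train_data.size)
print('validation:', validation_data.size)
print('test:', test_data.size)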
4. Configure the model:
model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
This will most likely not run locally as-is, and the key is the model_spec parameter. I mainly followed the approach from this link; two files matter:
- /home/***/anaconda3/envs/mediapipe_train/lib/python3.6/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/__init__.py: lines 27-63 list the model names for every task, so this is where you can see all available models. For object detection I used efficientdet_lite2.
- /home/***/anaconda3/envs/mediapipe_train/lib/python3.6/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py: here I changed the url of efficientdet_lite2_spec from https://tfhub.dev/tensorflow/efficientdet/lite2/feature-vector/1 to https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/lite2/feature-vector/1.tar.gz
With these two changes, the model spec is configured.
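Depending on your tflite-model-maker version, you may be able to avoid patching site-packages altogether by constructing the spec yourself with a custom uri; a sketch, assuming your release exports EfficientDetLite2Spec:
# Point the spec at the storage.googleapis.com mirror instead of tfhub.dev.
from tflite_model_maker.object_detector import EfficientDetLite2Spec
spec = EfficientDetLite2Spec(
    uri='https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/lite2/feature-vector/1.tar.gz')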
Evaluate the Model
Next comes evaluation, plus a second evaluation after the model has been converted to TFLite:
model.evaluate(test_data)
model.export(export_dir='.')
model.evaluate_tflite('model.tflite', test_data)
Regarding the drop in model accuracy after conversion, see this link.
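If that drop matters for your use case, export() also accepts a quantization_config; below is a sketch using float16 post-training quantization, which usually costs less accuracy than the default integer path (verify the API against your installed version):
# Export a float16-quantized model and evaluate it separately.
from tflite_model_maker.config import QuantizationConfig
quant_config = QuantizationConfig.for_float16()
model.export(export_dir='.', tflite_filename='model_fp16.tflite',
             quantization_config=quant_config)
model.evaluate_tflite('model_fp16.tflite', test_data)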
Visualization
Visualization code (I have already consolidated everything; the only thing left for you to do is pick a test image of your own, see the hint near the end of the code):
#@title Load the trained TFLite model and define some visualization functions
#@markdown This code comes from the TFLite Object Detection [Raspberry Pi sample](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/raspberry_pi).
import platform
import json
import cv2
from typing import List, NamedTuple, Optional
from tflite_support import metadata
import tensorflow as tf
assert tf.__version__.startswith('2')
import numpy as np
from PIL import Image
Interpreter = tf.lite.Interpreter
load_delegate = tf.lite.experimental.load_delegate
# pylint: enable=g-import-not-at-top
class ObjectDetectorOptions(NamedTuple):
"""A config to initialize an object detector."""
enable_edgetpu: bool = False
"""Enable the model to run on EdgeTPU."""
  label_allow_list: Optional[List[str]] = None
"""The optional allow list of labels."""
  label_deny_list: Optional[List[str]] = None
"""The optional deny list of labels."""
max_results: int = -1
"""The maximum number of top-scored detection results to return."""
num_threads: int = 1
"""The number of CPU threads to be used."""
score_threshold: float = 0.0
"""The score threshold of detection results to return."""
class Rect(NamedTuple):
"""A rectangle in 2D space."""
left: float
top: float
right: float
bottom: float
class Category(NamedTuple):
"""A result of a classification task."""
label: str
score: float
index: int
class Detection(NamedTuple):
"""A detected object as the result of an ObjectDetector."""
bounding_box: Rect
categories: List[Category]
def edgetpu_lib_name():
"""Returns the library name of EdgeTPU in the current platform."""
return {
'Darwin': 'libedgetpu.1.dylib',
'Linux': 'libedgetpu.so.1',
'Windows': 'edgetpu.dll',
}.get(platform.system(), None)
class ObjectDetector:
"""A wrapper class for a TFLite object detection model."""
_OUTPUT_LOCATION_NAME = 'location'
_OUTPUT_CATEGORY_NAME = 'category'
_OUTPUT_SCORE_NAME = 'score'
_OUTPUT_NUMBER_NAME = 'number of detections'
def __init__(
self,
model_path: str,
options: ObjectDetectorOptions = ObjectDetectorOptions()
) -> None:
"""Initialize a TFLite object detection model.
Args:
model_path: Path to the TFLite model.
options: The config to initialize an object detector. (Optional)
Raises:
ValueError: If the TFLite model is invalid.
OSError: If the current OS isn't supported by EdgeTPU.
"""
# Load metadata from model.
displayer = metadata.MetadataDisplayer.with_model_file(model_path)
# Save model metadata for preprocessing later.
model_metadata = json.loads(displayer.get_metadata_json())
process_units = model_metadata['subgraph_metadata'][0]['input_tensor_metadata'][0]['process_units']
mean = 0.0
std = 1.0
for option in process_units:
if option['options_type'] == 'NormalizationOptions':
mean = option['options']['mean'][0]
std = option['options']['std'][0]
self._mean = mean
self._std = std
# Load label list from metadata.
file_name = displayer.get_packed_associated_file_list()[0]
label_map_file = displayer.get_associated_file_buffer(file_name).decode()
label_list = list(filter(lambda x: len(x) > 0, label_map_file.splitlines()))
self._label_list = label_list
# Initialize TFLite model.
if options.enable_edgetpu:
if edgetpu_lib_name() is None:
raise OSError("The current OS isn't supported by Coral EdgeTPU.")
interpreter = Interpreter(
model_path=model_path,
experimental_delegates=[load_delegate(edgetpu_lib_name())],
num_threads=options.num_threads)
else:
interpreter = Interpreter(
model_path=model_path, num_threads=options.num_threads)
interpreter.allocate_tensors()
input_detail = interpreter.get_input_details()[0]
    # From TensorFlow 2.6, the order of the outputs becomes undefined.
# Therefore we need to sort the tensor indices of TFLite outputs and to know
# exactly the meaning of each output tensor. For example, if
# output indices are [601, 599, 598, 600], tensor names and indices aligned
# are:
# - location: 598
# - category: 599
# - score: 600
# - detection_count: 601
# because of the op's ports of TFLITE_DETECTION_POST_PROCESS
# (https://github.com/tensorflow/tensorflow/blob/a4fe268ea084e7d323133ed7b986e0ae259a2bc7/tensorflow/lite/kernels/detection_postprocess.cc#L47-L50).
sorted_output_indices = sorted(
[output['index'] for output in interpreter.get_output_details()])
self._output_indices = {
self._OUTPUT_LOCATION_NAME: sorted_output_indices[0],
self._OUTPUT_CATEGORY_NAME: sorted_output_indices[1],
self._OUTPUT_SCORE_NAME: sorted_output_indices[2],
self._OUTPUT_NUMBER_NAME: sorted_output_indices[3],
}
self._input_size = input_detail['shape'][2], input_detail['shape'][1]
self._is_quantized_input = input_detail['dtype'] == np.uint8
self._interpreter = interpreter
self._options = options
def detect(self, input_image: np.ndarray) -> List[Detection]:
"""Run detection on an input image.
Args:
input_image: A [height, width, 3] RGB image. Note that height and width
can be anything since the image will be immediately resized according
to the needs of the model within this function.
Returns:
      A list of Detection entities.
"""
image_height, image_width, _ = input_image.shape
input_tensor = self._preprocess(input_image)
self._set_input_tensor(input_tensor)
self._interpreter.invoke()
# Get all output details
boxes = self._get_output_tensor(self._OUTPUT_LOCATION_NAME)
classes = self._get_output_tensor(self._OUTPUT_CATEGORY_NAME)
scores = self._get_output_tensor(self._OUTPUT_SCORE_NAME)
count = int(self._get_output_tensor(self._OUTPUT_NUMBER_NAME))
return self._postprocess(boxes, classes, scores, count, image_width,
image_height)
def _preprocess(self, input_image: np.ndarray) -> np.ndarray:
"""Preprocess the input image as required by the TFLite model."""
# Resize the input
input_tensor = cv2.resize(input_image, self._input_size)
# Normalize the input if it's a float model (aka. not quantized)
if not self._is_quantized_input:
input_tensor = (np.float32(input_tensor) - self._mean) / self._std
# Add batch dimension
input_tensor = np.expand_dims(input_tensor, axis=0)
return input_tensor
def _set_input_tensor(self, image):
"""Sets the input tensor."""
tensor_index = self._interpreter.get_input_details()[0]['index']
input_tensor = self._interpreter.tensor(tensor_index)()[0]
input_tensor[:, :] = image
def _get_output_tensor(self, name):
"""Returns the output tensor at the given index."""
output_index = self._output_indices[name]
tensor = np.squeeze(self._interpreter.get_tensor(output_index))
return tensor
def _postprocess(self, boxes: np.ndarray, classes: np.ndarray,
scores: np.ndarray, count: int, image_width: int,
image_height: int) -> List[Detection]:
"""Post-process the output of TFLite model into a list of Detection objects.
Args:
boxes: Bounding boxes of detected objects from the TFLite model.
classes: Class index of the detected objects from the TFLite model.
scores: Confidence scores of the detected objects from the TFLite model.
count: Number of detected objects from the TFLite model.
image_width: Width of the input image.
image_height: Height of the input image.
Returns:
A list of Detection objects detected by the TFLite model.
"""
results = []
# Parse the model output into a list of Detection entities.
for i in range(count):
if scores[i] >= self._options.score_threshold:
y_min, x_min, y_max, x_max = boxes[i]
bounding_box = Rect(
top=int(y_min * image_height),
left=int(x_min * image_width),
bottom=int(y_max * image_height),
right=int(x_max * image_width))
class_id = int(classes[i])
category = Category(
score=scores[i],
label=self._label_list[class_id], # 0 is reserved for background
index=class_id)
result = Detection(bounding_box=bounding_box, categories=[category])
results.append(result)
    # Sort detection results by score, descending
sorted_results = sorted(
results,
key=lambda detection: detection.categories[0].score,
reverse=True)
# Filter out detections in deny list
filtered_results = sorted_results
if self._options.label_deny_list is not None:
filtered_results = list(
filter(
lambda detection: detection.categories[0].label not in self.
_options.label_deny_list, filtered_results))
# Keep only detections in allow list
if self._options.label_allow_list is not None:
filtered_results = list(
filter(
lambda detection: detection.categories[0].label in self._options.
label_allow_list, filtered_results))
    # Return at most max_results detections.
if self._options.max_results > 0:
result_count = min(len(filtered_results), self._options.max_results)
filtered_results = filtered_results[:result_count]
return filtered_results
_MARGIN = 10 # pixels
_ROW_SIZE = 10 # pixels
_FONT_SIZE = 1
_FONT_THICKNESS = 1
_TEXT_COLOR = (0, 0, 255)  # red in BGR order; renders as blue on the RGB arrays used below
def visualize(
image: np.ndarray,
detections: List[Detection],
) -> np.ndarray:
"""Draws bounding boxes on the input image and return it.
Args:
image: The input RGB image.
    detections: The list of all "Detection" entities to be visualized.
Returns:
Image with bounding boxes.
"""
for detection in detections:
# Draw bounding_box
start_point = detection.bounding_box.left, detection.bounding_box.top
end_point = detection.bounding_box.right, detection.bounding_box.bottom
cv2.rectangle(image, start_point, end_point, _TEXT_COLOR, 3)
# Draw label and score
category = detection.categories[0]
class_name = category.label
probability = round(category.score, 2)
result_text = class_name + ' (' + str(probability) + ')'
text_location = (_MARGIN + detection.bounding_box.left,
_MARGIN + _ROW_SIZE + detection.bounding_box.top)
cv2.putText(image, result_text, text_location, cv2.FONT_HERSHEY_PLAIN,
_FONT_SIZE, _TEXT_COLOR, _FONT_THICKNESS)
return image
#@title Run object detection and show the detection results
# INPUT_IMAGE_URL = "https://storage.googleapis.com/cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg" #@param {type:"string"}
DETECTION_THRESHOLD = 0.3 # @param {type:"number"}
TEMP_FILE = './tmp/image.jpg'  # put your own test image here (the ./tmp directory must exist)
# !wget -q -O $TEMP_FILE $INPUT_IMAGE_URL
im = Image.open(TEMP_FILE)
im.thumbnail((512, 512), Image.ANTIALIAS)
image_np = np.asarray(im)
# Load the TFLite model
options = ObjectDetectorOptions(
num_threads=4,
score_threshold=DETECTION_THRESHOLD,
)
detector = ObjectDetector(model_path='model.tflite', options=options)
# Run object detection estimation using the model.
detections = detector.detect(image_np)
# Draw keypoints and edges on input image
image_np = visualize(image_np, detections)
# Show the detection result
image = Image.fromarray(image_np)
image.save('./tmp/image_det.jpg')