python调用tflite执行推理
- tensorflow2.x
- 参考了tensorflow文档。
- 文档的不足之处在于,没有强调resize_tensor_input的使用。实际上,在allocate_tensors之前,需要resize以确定输入tensor的shape,保持与输入数据一致。
import tensorflow as tf
import numpy as np
def run_inference(data):
model = "model.tflite"
interpreter = tf.lite.Interpreter(model_path=model)
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.resize_tensor_input(input_details[0]['index'], data.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
# output_data结果根据需要进行reshape
return output_data