ckpt,pb,tflite使用和转换
ckpt,pb,tflite转换
一、ckpt,pb,tflite文件及其特点
ckpt模型文件
ckpt是tensorflow的默认的模型保存读取文件,包含四个部分:
- checkpoint
- model.ckpt.meta
- model.ckpt.index
- model.ckpt.data*
这四个文件将网络结构与权重数据分开保存,其中:
checkpoint:
记录模型目录下所有模型的文件列表
*ckpt.meta:
保存tensorflow计算图的网络结构
*ckpt.index:
保存了当前参数名
*ckpt.data:
保存了当前参数值
pb模型文件
pb模型是`graph_def`的序列化文件,参数已被固化,只能用来做前向预测。(虽然如此,也能很容易地获得模型结构,重新复现也会容易很多)
tflite文件
tf-lite主要是针对移动端进行优化的平台,重新定义了移动端的核心算子,也提供了硬件加速的接口,拥有新的优化解释器。
二、模型保存和恢复
ckpt模型保存与恢复
# Restore parameters: load the trainable variables from the checkpoint path
# recorded in `ckpt`.
# NOTE(review): `sess` and `ckpt` are not defined in this snippet; `ckpt` is
# presumably the result of tf.train.get_checkpoint_state(...) — confirm.
trainable_vars = list(tf.trainable_variables())
saver_restore = tf.train.Saver(var_list=trainable_vars)
saver_restore.restore(sess, save_path=ckpt.model_checkpoint_path)

# Save parameters: write the session's variables to "model.ckpt"
# (the saver is configured with max_to_keep=10).
saver = tf.train.Saver(max_to_keep=10)
saver.save(sess, save_path="model.ckpt")
pb模型加载
通过`tensor_name`获取节点:`get_tensor_by_name()`
# Read the .pb file and parse its bytes into a GraphDef.
graph_def = tf.GraphDef()
with tf.gfile.GFile(pb_path, 'rb') as pb_file:
    serialized_graph = pb_file.read()
graph_def.ParseFromString(serialized_graph)
# (Add print(graph_def) here to inspect the parsed graph.)

with tf.Graph().as_default() as graph:
    # Load graph_def into the default graph; name='' is passed as the import name.
    tf.import_graph_def(graph_def, name='')

    # Look the tensors up by name with get_tensor_by_name.
    input_tensor = graph.get_tensor_by_name('VIDEOSR/Slice:0')
    output_tensor = graph.get_tensor_by_name('%s:0' % out_node_name)

    # Execute with sess.run.
    # NOTE(review): `sess` is not created in this snippet; it presumably is a
    # tf.Session bound to `graph` — confirm. `pb_path`, `out_node_name` and
    # `image_in` also come from outside the snippet.
    image_out = sess.run(output_tensor, feed_dict={input_tensor: image_in})
...
tf-lite模型加载
通过`index`获取节点:`set_tensor()`,`get_tensor()`
def run_example_single(model_path, input_image, feature2, feature1):
    """Run a single inference pass of a TFLite model.

    Args:
        model_path: Path to the .tflite file, e.g.
            "model/save/converted_model.tflite".
        input_image: Value written to the model's first input tensor.
        feature2: Not used by this function.
        feature1: Not used by this function.

    Returns:
        The value of the model's first output tensor after invocation.
    """
    # Load the TFLite model and allocate its tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Fetch and print the input/output descriptions.
    inputs_info = interpreter.get_input_details()
    outputs_info = interpreter.get_output_details()
    print(inputs_info)
    print(outputs_info)

    # TFLite tensors are addressed by index; only the first input and the
    # first output are used here.
    in_index = inputs_info[0]['index']
    out_index = outputs_info[0]['index']

    interpreter.set_tensor(in_index, input_image)
    interpreter.invoke()
    return interpreter.get_tensor(out_index)
三、ckpt,pb,tf-lite之间的转换
ckpt转pb
ckpt转pb是对模型做持久化、将参数固化的过程,得到的pb模型一般只用于前向推理。可以参考官方代码`tensorflow/python/tools/freeze_graph.py`
流程:
- 加载ckpt模型
- 将图使用`tf.train.write_graph()`写出
- 使用`freeze_graph.freeze_graph()`把模型参数固化保存
import tensorflow as tf
import os
import slim.nets.mobilenet_v1 as mobilenet_v1
import tensorflow.contrib.slim as slim
from tensorflow.python.tools import freeze_graph
def export_eval_pbtxt(MODEL_SAVE_PATH):
    """Export the MobileNet-V1 eval graph as .pbtxt and freeze it into a .pb.

    Builds the network with ``is_training=False``, checks that the latest
    checkpoint under ``MODEL_SAVE_PATH`` can be restored, writes the graph
    definition to ``<MODEL_SAVE_PATH>/pb_model/mobilenet_v1_eval.pbtxt`` and
    then calls ``freeze_graph.freeze_graph`` to produce
    ``<MODEL_SAVE_PATH>/pb_model/frozen_model.pb``.

    Args:
        MODEL_SAVE_PATH: Directory holding the checkpoint files; the
            ``pb_model`` output directory is created inside it if missing.

    Raises:
        IOError: If no checkpoint state is found under ``MODEL_SAVE_PATH``
            (freezing needs a checkpoint path, so there is nothing to do).
    """
    with tf.Graph().as_default() as g:
        images = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='input')
        # Build the network with is_training=False (evaluation graph).
        with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False, regularize_depthwise=True)):
            _, _ = mobilenet_v1.mobilenet_v1(inputs=images, is_training=False, depth_multiplier=1.0, num_classes=7)
        saver = tf.train.Saver(max_to_keep=5)

        pb_dir = os.path.join(MODEL_SAVE_PATH, 'pb_model')
        graph_file = os.path.join(pb_dir, 'mobilenet_v1_eval.pbtxt')
        frozen_model = os.path.join(pb_dir, 'frozen_model.pb')

        checkpoint = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if not (checkpoint and checkpoint.model_checkpoint_path):
            # Fail with a clear error instead of an AttributeError later on
            # `checkpoint.model_checkpoint_path`.
            print("Could not find old network weights")
            raise IOError("No checkpoint found in %s" % MODEL_SAVE_PATH)

        # The output directory must exist before the .pbtxt can be written.
        if not os.path.isdir(pb_dir):
            os.makedirs(pb_dir)

        with tf.Session() as sess:
            try:
                saver.restore(sess, checkpoint.model_checkpoint_path)
            except (tf.errors.OpError, ValueError) as err:
                # Best effort, as before: report the failure and continue.
                # freeze_graph is given the checkpoint path itself below.
                print("Error on loading old network weights")
                print(err)
            else:
                print("Successfully loaded:", checkpoint.model_checkpoint_path)
            print('Learning Started!')

            # Dump the graph definition in text form for freeze_graph.
            with open(graph_file, 'w') as f:
                f.write(str(g.as_graph_def()))

            freeze_graph.freeze_graph(
                input_graph=graph_file,
                input_saver='',
                input_binary=False,  # graph_file is a text-format .pbtxt
                input_checkpoint=checkpoint.model_checkpoint_path,
                output_node_names="MobilenetV1/Predictions/Softmax",
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                output_graph=frozen_model,
                clear_devices=True,
                initializer_nodes="")
pb模型转tflite模型
- 将pb模型加载:`tf.lite.TFLiteConverter.from_frozen_graph()`
- 对模型进行转换:`converter.convert()`
- 将转换后的结果保存到文件
def pb_to_tflite(input_name, output_name):
    """Convert ``pb_model/frozen_model.pb`` into a TFLite model file.

    Reads the frozen graph from ``<MODEL_SAVE_PATH>/pb_model/frozen_model.pb``
    and writes the converted model to
    ``<MODEL_SAVE_PATH>/tflite_model/converted_model.tflite``.

    NOTE(review): ``MODEL_SAVE_PATH`` is not a parameter; it is expected to be
    a module-level name defined elsewhere — confirm.

    Args:
        input_name: Name passed as the single entry of the converter's
            input arrays.
        output_name: Name passed as the single entry of the converter's
            output arrays.
    """
    graph_def_file = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'frozen_model.pb')
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, [input_name], [output_name])
    tflite_model = converter.convert()

    # Make sure the output directory exists before writing.
    tflite_dir = os.path.join(MODEL_SAVE_PATH, 'tflite_model')
    if not os.path.isdir(tflite_dir):
        os.makedirs(tflite_dir)
    tflite_file = os.path.join(tflite_dir, 'converted_model.tflite')

    # Use a context manager so the file handle is closed deterministically.
    with open(tflite_file, "wb") as f:
        f.write(tflite_model)