ckpt,pb,tflite使用和转换

ckpt,pb,tflite转换

一、ckpt,pb,tflite文件及其特点

ckpt模型文件

ckpt是tensorflow的默认的模型保存读取文件,包含四个部分:

  • checkpoint
  • model.ckpt.meta
  • model.ckpt.index
  • model.ckpt.data*

是结构权重数据分离的四个文件,其中

checkpoint:记录模型目录下所有模型的文件列表

*ckpt.meta:保存tensorflow计算图的网络结构

*ckpt.index:保存了当前参数名

*ckpt.data:保存了当前参数值

pb模型文件

pb模型是graph_def的序列化文件,固化参数,只能用来做前向预测。(虽然如此,也能很容易的获得模型结构,重新复现也会容易很多)

tflite文件

tf-lite主要是针对移动端进行优化的平台,重新定义了移动端的核心算子,也提供了硬件加速的接口,拥有新的优化解释器。

二、模型保存和恢复

ckpt模型保存与恢复

# 参数恢复
saver_restore = tf.train.Saver([var for var in tf.trainable_variables()])
saver_restore.restore(sess, ckpt.model_checkpoint_path)

# 参数保存
saver = tf.train.Saver(max_to_keep=10)
saver.save(sess, "model.ckpt")

pb模型加载

通过tensor_name获取节点:get_tensor_by_name()

# 读文件到graph_def
with tf.gfile.GFile(pb_path, 'rb') as fgraph:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fgraph.read())
    # print(graph_def)
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='') # 把graph_def 加载到default_graph
    
    # 使用get_tensor_by_name获取tensor
    input_tensor = graph.get_tensor_by_name('VIDEOSR/Slice:0')
    output_tensor = graph.get_tensor_by_name('%s:0' % out_node_name)
    
    # 使用sess.run执行
    image_out = sess.run(output_tensor, feed_dict={input_tensor: image_in})
    ...

tf-lite模型加载

通过index获取节点:set_tensor(),get_tensor()

def run_example_single(model_path,input_image,feature2,feature1):
    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)    # "model/save/converted_model.tflite"
    interpreter.allocate_tensors()

    # get input output info
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)

    # inputs index
    index_inImg = input_details[0]['index']

    # outputs index
    index_outImg = output_details[0]['index']

    # set inputs
    interpreter.set_tensor(index_inImg, input_image)

    # invoke
    interpreter.invoke()

    # get results
    outImg = interpreter.get_tensor(index_outImg)

    return outImg

三、ckpt,pb,tf-lite之间的转换

ckpt转pb

ckpt转pb是模型的持久化,固化参数的结果,一般只做前向。可以参考官方代码``

流程:

  1. 加载ckpt模型
  2. 将图使用tf.train.write_graph()写出
  3. 使用freeze_graph.freeze_graph()把模型参数固化保存
import tensorflow as tf
import os
import slim.nets.mobilenet_v1 as mobilenet_v1
import tensorflow.contrib.slim as slim
from tensorflow.python.tools import freeze_graph
 
 
def export_eval_pbtxt(MODEL_SAVE_PATH):
    """Export eval.pbtxt."""
    with tf.Graph().as_default() as g:
        images = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='input')
        # is_training=False会把BN层去掉
        with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False, regularize_depthwise=True)):
            _, _ = mobilenet_v1.mobilenet_v1(inputs=images, is_training=False, depth_multiplier=1.0, num_classes=7)
 
        saver = tf.train.Saver(max_to_keep=5)
        pb_dir = os.path.join(MODEL_SAVE_PATH, 'pb_model')
 
        graph_file = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'mobilenet_v1_eval.pbtxt')
        
        checkpoint = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        frozen_model = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'frozen_model.pb')
        
        with tf.Session() as sess:
            if checkpoint and checkpoint.model_checkpoint_path:
                try:
                    saver.restore(sess, checkpoint.model_checkpoint_path)
                    print("Successfully loaded:", checkpoint.model_checkpoint_path)
                except:
                    print("Error on loading old network weights")
            else:
                print("Could not find old network weights")
 
            print('Learning Started!')
            with open(graph_file, 'w') as f:
                f.write(str(g.as_graph_def()))
            freeze_graph.freeze_graph(graph_file,
                                      '',
                                      False,
                                      checkpoint.model_checkpoint_path,
                                      "MobilenetV1/Predictions/Softmax",
                                      'save/restore_all',
                                      'save/Const:0',
                                      frozen_model,
                                      True,
                                      "")

pb模型转tflite模型

  1. 将pb模型加载tf.lite.TFLiteConverter.from_frozen_graph()
  2. 对模型进行转换converter.convert()
  3. 将转换 后的结果保存在文件
def pb_to_tflite(input_name, output_name):
    graph_def_file = os.path.join(MODEL_SAVE_PATH, 'pb_model', 'frozen_model.pb')
    input_arrays = [input_name]
    output_arrays = [output_name]
 
    converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
    tflite_model = converter.convert()
    tflite_file = os.path.join(MODEL_SAVE_PATH, 'tflite_model', 'converted_model.tflite')
    open(tflite_file, "wb").write(tflite_model)
posted @ 2020-08-18 15:59  wioponsen  阅读(5489)  评论(0编辑  收藏  举报