TensorRT推理加速-基于Tensorflow(keras)的uff格式模型(文件准备)
一、引子//Windows
tf(keras)训练好了模型,想要用Nvidia-TensorRT来重构训练好的模型为TRT推理引擎加快推理的速度。
二、准备文件
1、训练好模型以后(keras)可以通过以下方式保存keras模型为h5文件
tf.keras.models.save_model(model, 'keras_model\\classify.h5')
2、再通过以下代码来将h5文件转化为pb文件
import tensorflow.compat.v1 as tf1 tf1.reset_default_graph() tf1.keras.backend.set_learning_phase(0) # 调用模型前一定要执行该命令 tf1.disable_v2_behavior() # 禁止tensorflow2.0的行为 # 加载hdf5模型 hdf5_pb_model = tf1.keras.models.load_model('keras_model\\classify.h5') def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): graph = session.graph with graph.as_default(): # freeze_var_names = list(set(v.op.name for v in tf1.global_variables()).difference(keep_var_names or [])) output_names = output_names or [] # output_names += [v.op.name for v in tf1.global_variables()] print("output_names", output_names) input_graph_def = graph.as_graph_def() # for node in input_graph_def.node: # print('node:', node.name) print("len node1", len(input_graph_def.node)) if clear_devices: for node in input_graph_def.node: node.device = "" frozen_graph = tf1.graph_util.convert_variables_to_constants(session, input_graph_def, output_names) outgraph = tf1.graph_util.remove_training_nodes(frozen_graph) # 云掉与推理无关的内容 print("##################################################################") for node in outgraph.node: print('node:', node.name) print("len node1", len(outgraph.node)) return outgraph output_folder2 = 'keras_model' frozen_graph = freeze_session(tf1.compat.v1.keras.backend.get_session(), output_names=[out.op.name for out in hdf5_pb_model.outputs]) tf1.train.write_graph(frozen_graph, output_folder2, "classify.pb", as_text=False)
3、注意:以上代码基于tf2.0运行
4、pb模型文件转化为uff模型文件(tensorrt解析tf模型只能用uff格式)
首先,先安装TensorRT自带的(两个文件就在trt文件夹里面,cd到路径)
pip install uff-0.6.5-py2.py3-none-any.whl pip install graphsurgeon-0.4.1-py2.py3-none-any.whl
5、执行(cd到路径,执行以下过程需要tf1.x版本,否则报错,没有Graphdef)
转换
convert-to-uff xxxx.pb -o xxxx.uff
查看模型信息
convert-to-uff xxxx.uff -l
参考:
【Tensorflow2.0】8、tensorflow2.0_hdf5_savedmodel_pb模型转换