TensorRT推理加速-基于Tensorflow(keras)的uff格式模型(文件准备)
一、引子//Windows
tf(keras)训练好了模型,想要用Nvidia-TensorRT来重构训练好的模型为TRT推理引擎加快推理的速度。
二、准备文件
1、训练好模型以后(keras)可以通过以下方式保存keras模型为h5文件
1 | tf.keras.models.save_model(model, 'keras_model\\classify.h5' ) |
2、再通过以下代码来将h5文件转化为pb文件
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | import tensorflow.compat.v1 as tf1 tf1.reset_default_graph() tf1.keras.backend.set_learning_phase( 0 ) # 调用模型前一定要执行该命令 tf1.disable_v2_behavior() # 禁止tensorflow2.0的行为 # 加载hdf5模型 hdf5_pb_model = tf1.keras.models.load_model( 'keras_model\\classify.h5' ) def freeze_session(session, keep_var_names = None , output_names = None , clear_devices = True ): graph = session.graph with graph.as_default(): # freeze_var_names = list(set(v.op.name for v in tf1.global_variables()).difference(keep_var_names or [])) output_names = output_names or [] # output_names += [v.op.name for v in tf1.global_variables()] print ( "output_names" , output_names) input_graph_def = graph.as_graph_def() # for node in input_graph_def.node: # print('node:', node.name) print ( "len node1" , len (input_graph_def.node)) if clear_devices: for node in input_graph_def.node: node.device = "" frozen_graph = tf1.graph_util.convert_variables_to_constants(session, input_graph_def, output_names) outgraph = tf1.graph_util.remove_training_nodes(frozen_graph) # 云掉与推理无关的内容 print ( "##################################################################" ) for node in outgraph.node: print ( 'node:' , node.name) print ( "len node1" , len (outgraph.node)) return outgraph output_folder2 = 'keras_model' frozen_graph = freeze_session(tf1.compat.v1.keras.backend.get_session(), output_names = [out.op.name for out in hdf5_pb_model.outputs]) tf1.train.write_graph(frozen_graph, output_folder2, "classify.pb" , as_text = False ) |
3、注意:以上代码基于tf2.0运行
4、pb模型文件转化为uff模型文件(tensorrt解析tf模型只能用uff格式)
首先,先安装TensorRT自带的(两个文件就在trt文件夹里面,cd到路径)
1 2 | pip install uff - 0.6 . 5 - py2.py3 - none - any .whl pip install graphsurgeon - 0.4 . 1 - py2.py3 - none - any .whl |
5、执行(cd到路径,执行以下过程需要tf1.x版本,否则报错,没有Graphdef)
转换
1 | convert - to - uff xxxx.pb - o xxxx.uff |
查看模型信息
1 | convert - to - uff xxxx.uff - l |
参考:
1 | 【Tensorflow2. 0 】 8 、tensorflow2. 0_hdf5_savedmodel_pb 模型转换 |
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步