ONNX模型转换（先将 Keras 模型冻结为 TensorFlow pb 文件）
import os
from functools import lru_cache

import tensorflow as tf
from tensorflow.python.framework import importer
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


def freeze_keras_model2pb(keras_model, pb_filepath, input_variable_name_list=None, output_variable_name_list=None):
    """Freeze a Keras model into a binary TensorFlow GraphDef (.pb) file.

    Variables are folded into constants, every node's device assignment is
    cleared, and the input placeholders / output nodes are given predictable
    names so the frozen graph can be reloaded with
    ``pb_file_to_concrete_function`` or handed to a converter such as tf2onnx.

    :param keras_model: model to convert; must expose ``inputs`` (built with the
        functional API or as a Sequential model with a known input).
    :param pb_filepath: path of the .pb file to write.
    :param input_variable_name_list: names for the input placeholders. Used only
        when it has exactly one name per model input; otherwise ``x0, x1, ...``.
    :param output_variable_name_list: names for the output nodes. Used only when
        it has exactly one name per graph output; otherwise ``y0, y1, ...``.
    :return: ``(written_path, input_names, output_names)`` where ``written_path``
        is the value returned by ``tf.io.write_graph``.
    """
    assert hasattr(keras_model, 'inputs'), "the keras model must be built with functional api or sequential"

    if input_variable_name_list is None:
        input_variable_name_list = list()
    if output_variable_name_list is None:
        output_variable_name_list = list()

    # Fall back to generated names when the caller's list does not match the model's input count.
    if len(input_variable_name_list) == len(keras_model.inputs):
        input_variable_list = input_variable_name_list
    else:
        input_variable_list = ['x%d' % i for i in range(len(keras_model.inputs))]

    # The TensorSpec names become the placeholder node names of the traced graph.
    input_func_signature_list = [
        tf.TensorSpec(item.shape, dtype=item.dtype, name=name)
        for name, item in zip(input_variable_list, keras_model.inputs)
    ]
    full_model = tf.function(lambda *x: keras_model(x, training=False))
    # get_concrete_function traces the tf.function for this signature and returns
    # a ConcreteFunction backed by a single tf.Graph.
    concrete_func = full_model.get_concrete_function(input_func_signature_list)
    # Fold all variables into constants ("freeze" the graph).
    frozen_graph = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_graph.graph.as_graph_def()

    # Op names of the graph outputs, in output order. Taking them from the
    # function's outputs avoids guessing by name prefix ("Identity...").
    output_op_names = [tensor.op.name for tensor in frozen_graph.outputs]
    if len(output_variable_name_list) == len(output_op_names):
        output_variable_list = output_variable_name_list
    else:
        output_variable_list = ['y%d' % i for i in range(len(output_op_names))]
    new_output_names = dict(zip(output_op_names, output_variable_list))

    for node in graph_def.node:
        # Clear device pins so the graph can be loaded on any device.
        node.device = ""
        if node.name in new_output_names:
            node.name = new_output_names[node.name]

    new_graph = tf.Graph()
    with new_graph.as_default():
        importer.import_graph_def(graph_def, name="")

    # os.path.dirname("model.pb") is "", so default to the current directory.
    logdir = os.path.dirname(pb_filepath) or "."
    written_path = tf.io.write_graph(graph_or_graph_def=new_graph,
                                     logdir=logdir,
                                     name=os.path.basename(pb_filepath),
                                     as_text=False)
    return written_path, input_variable_list, output_variable_list


def _as_tensor_name(name):
    """Append the ``:0`` output index to a bare op name; leave anything else unchanged.

    TF op names cannot contain ':', so a string without one is an op name,
    and ``Graph.as_graph_element`` would resolve it to an Operation rather
    than a Tensor.
    """
    if isinstance(name, str) and ':' not in name:
        return name + ':0'
    return name


def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    """Wrap a frozen GraphDef as a callable function.

    :param graph_def: a ``tf.compat.v1.GraphDef`` (e.g. parsed from a .pb file).
    :param inputs: (nested structure of) input tensor names. A bare op name such
        as ``"x0"`` is treated as that op's first output (``"x0:0"``), which
        matches the names returned by ``freeze_keras_model2pb``.
    :param outputs: (nested structure of) output tensor names; same convention.
    :param print_graph: if True, print the name of every op in the imported graph.
    :return: the function returned by ``WrappedFunction.prune``, mapping
        ``inputs`` to ``outputs``.
    """
    def _imports_graph_def():
        tf.graph_util.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    if print_graph:
        print("-" * 50)
        print("Frozen model layers: ")
        for op in import_graph.get_operations():
            print(op.name)
        print("-" * 50)

    # prune() requires Tensor feeds; bare op names would resolve to Operations.
    input_names = tf.nest.map_structure(_as_tensor_name, inputs)
    output_names = tf.nest.map_structure(_as_tensor_name, outputs)
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, input_names),
        tf.nest.map_structure(import_graph.as_graph_element, output_names))


def pb_file_to_concrete_function(pb_file, inputs, outputs, print_graph=False):
    """Load a frozen .pb file and wrap it as a callable function.

    :param pb_file: path of the binary GraphDef file.
    :param inputs: input tensor (or op) names, e.g. as returned by ``freeze_keras_model2pb``.
    :param outputs: output tensor (or op) names.
    :param print_graph: if True, print every op name of the loaded graph.
    :return: ``(graph_def, frozen_func)`` - the parsed GraphDef and the wrapped function.
    """
    with tf.io.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    frozen_func = wrap_frozen_graph(graph_def=graph_def, inputs=inputs, outputs=outputs, print_graph=print_graph)
    return graph_def, frozen_func


if __name__ == "__main__":
    # Demo. tf2onnx must be installed beforehand to convert the resulting .pb to ONNX.
    # `model` and `max_len` are placeholders: supply your own model and sequence length.
    # The TF model has to be wrapped in a tf.keras.Model before conversion.
    ind_input, seg_input = tf.keras.layers.Input([max_len]), tf.keras.layers.Input([max_len])
    your_keras_model = tf.keras.Model(inputs=[ind_input, seg_input], outputs=model([ind_input, seg_input]))
    # Save the frozen (static) TensorFlow graph as a .pb file.
    _, input_variable_list, output_variable_list = freeze_keras_model2pb(your_keras_model, "your_keras_model.pb")
    # Reload the frozen graph as a callable function.
    graph_def, frozen_model = pb_file_to_concrete_function("your_keras_model.pb", input_variable_list,
                                                           output_variable_list)
    # frozen_model is now ready to use.
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· Vue3状态管理终极指南:Pinia保姆级教程