onnx 增删改查,修改节点,删除节点,修改input,output
一、onnx 的数据类型,共有16种
elem_type: 1 --> float32 elem_type: 2 --> uint8 elem_type: 3 --> int8 elem_type: 4 --> uint16 elem_type: 5 --> int16 elem_type: 6 --> int32 elem_type: 7 --> int64 elem_type: 8 --> string elem_type: 9 --> boolean elem_type: 10 --> float16 elem_type: 11 --> float64 elem_type: 12 --> uint32 elem_type: 14 --> uint64 elem_type: 15 --> complex128 elem_type: 16 --> bfloat16
二、onnx 节点打印
onnx_print.py: python onnx_print.py model.onnx
import onnx import numpy as np import sys ori_file=sys.argv[1] onnx_model = onnx.load(ori_file) #print(onnx_model) #print(onnx_model.output) graph = onnx_model.graph print("model input: ") print(graph.input) for input_node in onnx_model.graph.input: print("input data name",input_node.name) print("\nmodel output: ") print(graph.output) for output_node in onnx_model.graph.output: print("output data name",output_node.name) print("\nall model graph") print("all_node_number: " + str(len(graph.node))) #节点个数 num = 0 for each in graph.node: print(num,"node") print(each) num +=1
三、onnx 输出信息表示
向量模型举例..... 模型输入: model input: [name: "input_ids" #名称 type { tensor_type { elem_type: 1 #输入类型 shape { dim { dim_value: 1 #输入第一维度 } dim { dim_value: 512 #输入第二维度 } } } } ] model output: [name: "1329" #模型输出的节点名称,对应到onnx的node的output type { tensor_type { elem_type: 1 #输出类型 shape { dim { dim_value: 1#输的第一维度 } dim { dim_value: 768 #输出的第二维度 } } } } ] 节点: input: "/encoder/encoder/layer.11/output/LayerNorm/Pow_output_0" #输入节点名称,对应上层output output: "/encoder/encoder/layer.11/output/LayerNorm/ReduceMean_1_output_0" #输出节点名,对应下层input,或者模型 output name: "/encoder/encoder/layer.11/output/LayerNorm/ReduceMean_1" #当前节点名称 op_type: "ReduceMean" #当前节点的op 操作 attribute { #节点属性 name: "axes" type: INTS ints: -1 }
四、onnx 增,删,改:最下面有示例
import onnx import numpy as np import sys from onnx import helper ori_file=sys.argv[1] #原始模型 model.onnx onnxfile=sys.argv[2] #生成后的模型 model_modify.onnx onnx_model = onnx.load(ori_file) graph = onnx_model.graph # 创建新节点 #new_output = helper.make_tensor_value_info(new_output_name, # model.graph.output[len(model.graph.output) - 1].type.tensor_type.elem_type, # model.graph.output[len(model.graph.output) - 1].type.tensor_type.shape) modify_input_node=False modify_output_node=False delete_node=False add_node=False print("all_node_number: " + str(len(graph.node))) all_node_len = len(graph.node) node_index = all_node_len #修改input节点名称,以及类型; if(modify_input_node): for index,eachNode in enumerate(graph.input): now_name = graph.input[index].name new_input = helper.make_tensor_value_info(now_name, graph.input[index].type.tensor_type.elem_type, [1,256]) #修改模型输出维度 graph.input.remove(graph.input[index]) #删除旧节点, #graph.input.append(new_input) #插入新节点,顺序会乱 graph.input.insert(index,new_input) #插入新节点,可以保证之前的顺序 if(modify_output_node):#修改 output 名称或者维度,或者类型.... for index,eachNode in enumerate(graph.output): now_name = eachNode.name if(now_name == "1329"): new_output = helper.make_tensor_value_info("last_output_1", graph.output[index].type.tensor_type.elem_type, [1,768]) # 维度是list,不能是*tensor_type.shape graph.output.remove(graph.output[index]) #删除旧节点, #graph.input.append(new_input) #插入新节点,顺序会乱 graph.output.insert(index,new_output) #插入新节点,保证原始顺序 if( delete_node == True ): #删除节点 print("need_modify index : " + str(node_index)) num = 0 for i in range(0,10): #删除最后10层模型,也可指定index删除,但是遍历或许有问题 graph.node.remove(graph.node[-1]) last_node = graph.node[-1] print("\last_node: ",last_node) print("after_delete_node_number: " + str(len(graph.node))) if(add_node == True):#最后一层增加softmax,但是也需要修改output节点名称和维度 last_index = len(graph.node) old_node = graph.node[-1] print(old_node.output) new_node = onnx.helper.make_node( 'Softmax', name='Softmax_1234', inputs=old_node.output, outputs=["last_output_1"], #是列表,传入字符串报错,but单个字符串不报错 ) #graph.node.add(new_node) #这样操作会报错,TypeError: No positional arguments allowed graph.node.insert(last_index,new_node) # 最后一个节点插入 print("add : \n",graph.node[-1]) print("elem_type: ",graph.output[-1].type.tensor_type.elem_type) #打印elem_type ''' onnx.checker.check_model() 检查模型的一致性,即模型在结构、格式和配置方面的正确性和完整性。 model:要检查的模型。如果模型是一个路径,函数会首先检查模型路径。如果模型的字节大小超过2GB,应使用模型路径来调用该函数。 full_check:如果为 True,函数还会运行形状推断检查。 skip_opset_compatibility_check:如果为 True,函数将跳过算子集兼容性检查。 check_custom_domain:如果为 True,函数将检查所有域。否则,仅检查内置域。 ''' #onnx.checker.check_model(onnx_model) onnx.checker.check_model(onnx_model,full_check=True) onnx.save(onnx_model, onnxfile)
效果示例如下:右侧处理后结果
修改input维度:
修改output名称:
删除最后10个节点:仅仅展示效果,实际删除后或许还需要修改 outout名称为对应node output
最后一层添加个softmax:仅仅展示效果,实际删除后或许还需要修改 outout名称为对应node output