ONNX结构及子图抽取
ONNX结构
https://mmdeploy.readthedocs.io/zh-cn/latest/tutorial/05_onnx_model_editing.html
https://blog.csdn.net/u013597931/article/details/84401047
https://zhuanlan.zhihu.com/p/346511883
https://zhuanlan.zhihu.com/p/353613079
https://zhuanlan.zhihu.com/p/350702340
https://zhuanlan.zhihu.com/p/350833729
ONNX抽取子图
https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md
https://blog.csdn.net/u013597931/article/details/84635779
import onnx
# Extract a sub-model from an existing ONNX file.
# The paths below are placeholders: point them at a real model and destination.
input_path = "path/to/the/original/model.onnx"
output_path = "path/to/save/the/extracted/model.onnx"

# Tensor names that become the graph inputs / outputs of the extracted sub-model.
input_names = [f"input_{idx}" for idx in range(3)]
output_names = [f"output_{idx}" for idx in range(2)]

onnx.utils.extract_model(
    input_path=input_path,
    output_path=output_path,
    input_names=input_names,
    output_names=output_names,
)
ONNX 获取每层节点名以及每层输出（并与参考数据比对）
import os
import glob
import subprocess
import numpy as np
import onnxruntime as ort
import onnx
import sys
import shutil
import copy
import json
from collections import OrderedDict
from utilities.utils import *
# Inputs and the tensor under test are collected here so they can be changed
# without touching the logic below.
MODEL_PATH = "./tmp/deeplabv3.onnx"
INPUT_PATH = "./tmp/inputdata_fp32.bin"
INPUT_SHAPE = (8, 519, 519, 3)
REFERENCE_PATH = "./tmp/test_dump/stc_cpu/aspp.bin"
TARGET_TENSOR = "aspp/conv_1x1_concat/Relu:0"
CPU_NUM_THREADS = 4


def _expose_all_node_outputs(model):
    """Append every node output tensor to ``model.graph.output`` (in place).

    Tensors that are already graph outputs are skipped, as are empty names
    (ONNX uses "" for an omitted optional output).
    """
    # Compare by *name*: graph.output holds ValueInfoProto messages, so testing
    # a plain str against that container never matches and would re-add every
    # original output as a duplicate.
    existing = {value_info.name for value_info in model.graph.output}
    for node in model.graph.node:
        for name in node.output:
            if name and name not in existing:
                model.graph.output.extend([onnx.ValueInfoProto(name=name)])
                existing.add(name)


def _create_session(model):
    """Build an onnxruntime InferenceSession from an in-memory ModelProto."""
    config = ort.SessionOptions()
    # NOTE(review): full graph optimization may change numerics slightly
    # relative to the reference dump; confirm this level is intended.
    config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    if ort.get_device() == "CPU":
        config.intra_op_num_threads = CPU_NUM_THREADS
        providers = ["CPUExecutionProvider"]
    else:
        # config/providers used to be defined only on CPU builds, so any other
        # build failed with NameError; use whatever this build provides.
        providers = ort.get_available_providers()
    return ort.InferenceSession(
        model.SerializeToString(), providers=providers, sess_options=config
    )


def _find_output(result_outs, pattern):
    """Return the first key of *result_outs* containing *pattern*, or None."""
    return next((key for key in result_outs if pattern in key), None)


def main():
    """Run the model with every node output exposed, then compare one tensor
    against a reference binary dump using ``verify_atol_rtol``."""
    model = onnx.load(MODEL_PATH)
    _expose_all_node_outputs(model)
    ort_session = _create_session(model)

    input_data = np.fromfile(INPUT_PATH, dtype="float32").reshape(INPUT_SHAPE)
    ort_inputs = {ort_session.get_inputs()[0].name: input_data}

    # Request every session output: the original model output(s) plus all the
    # intermediate tensors exposed above.
    outputs = [x.name for x in ort_session.get_outputs()]
    ort_outs = ort_session.run(output_names=outputs, input_feed=ort_inputs)
    # Map output name -> array for lookup by tensor name.
    result_outs = OrderedDict(zip(outputs, ort_outs))

    key = _find_output(result_outs, TARGET_TENSOR)
    if key is None:
        print(f"no output matching {TARGET_TENSOR!r} found")
        return
    print("key: ", key)
    print("value shape: ", result_outs[key].shape)
    stc_out = np.fromfile(REFERENCE_PATH, dtype="float32").reshape(
        result_outs[key].shape
    )
    # Loose tolerances; stricter settings (e.g. atol=1e-6, rtol=5e-6) were
    # also used during debugging.
    verify_atol_rtol(stc_out, result_outs[key], atol=1e-2, rtol=5e-2)


if __name__ == "__main__":
    main()