ONNX structure and subgraph extraction

ONNX structure

https://mmdeploy.readthedocs.io/zh-cn/latest/tutorial/05_onnx_model_editing.html

https://blog.csdn.net/u013597931/article/details/84401047

https://zhuanlan.zhihu.com/p/346511883

https://zhuanlan.zhihu.com/p/353613079

https://zhuanlan.zhihu.com/p/350702340

https://zhuanlan.zhihu.com/p/350833729
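
The links above cover the ONNX file format in detail. As a quick illustration of the structure they describe (a ModelProto wrapping a GraphProto, which holds nodes, inputs, outputs and initializers), here is a minimal sketch of my own, not taken from those posts, that builds a tiny graph with onnx.helper and then walks it:

import onnx
from onnx import helper, TensorProto

# Graph inputs/outputs are ValueInfoProto objects (name + element type + shape).
a = helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 3])
b = helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 3])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3])

# A NodeProto records the op type plus the names of its input and output tensors.
add_node = helper.make_node("Add", inputs=["a", "b"], outputs=["y"], name="add_0")

graph = helper.make_graph([add_node], "toy_graph", inputs=[a, b], outputs=[y])
model = helper.make_model(graph)
onnx.checker.check_model(model)

# Every node exposes .name, .op_type, .input and .output; the per-layer
# dumping script further below relies on exactly these fields.
for node in model.graph.node:
    print(node.name, node.op_type, list(node.input), list(node.output))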

ONNX subgraph extraction

https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md
https://blog.csdn.net/u013597931/article/details/84635779

import onnx

# extract_model cuts out the subgraph bounded by the given tensor names and
# saves it as a standalone model; the names must exist in the original graph.
input_path = "path/to/the/original/model.onnx"
output_path = "path/to/save/the/extracted/model.onnx"
input_names = ["input_0", "input_1", "input_2"]
output_names = ["output_0", "output_1"]

onnx.utils.extract_model(input_path, output_path, input_names, output_names)
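
input_names and output_names must be tensor (value) names that actually exist in the original graph, otherwise extract_model raises an error. A small helper like the one below, a hypothetical sketch of mine rather than anything from the onnx docs, lists the candidate names by walking graph inputs and node outputs:

import onnx

def list_tensor_names(model_path):
    # Candidate cut points for extract_model: graph inputs plus every node output.
    model = onnx.load(model_path)
    names = [i.name for i in model.graph.input]
    for node in model.graph.node:
        names.extend(node.output)
    return names

# Example usage (the path is a placeholder):
# for name in list_tensor_names("path/to/the/original/model.onnx"):
#     print(name)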

ONNX: get every node's name and dump per-layer outputs

from collections import OrderedDict

import numpy as np
import onnx
import onnxruntime as ort

from utilities.utils import *  # project-local helpers, provides verify_atol_rtol

if __name__ == "__main__":

    model = onnx.load("./tmp/deeplabv3.onnx")

    # Expose every node output as a graph output so onnxruntime returns the
    # intermediate tensors as well as the original model output.
    ori_output_names = {out.name for out in model.graph.output}
    for node in model.graph.node:
        for output in node.output:
            if output not in ori_output_names:
                model.graph.output.extend([onnx.ValueInfoProto(name=output)])

    # Session configuration
    if ort.get_device() == "CPU":
        config = ort.SessionOptions()
        cpu_num_thread = 4
        config.intra_op_num_threads = cpu_num_thread
        config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        providers = ["CPUExecutionProvider"]
        ort_session = ort.InferenceSession(model.SerializeToString(), providers=providers, sess_options=config)

        input = np.fromfile("./tmp/inputdata_fp32.bin", dtype='float32').reshape(8, 519, 519, 3)
        ort_inputs = {ort_session.get_inputs()[0].name: input}

        # Collect all output names; ort_session.get_outputs()[0].name is the original
        # model's single output, the rest are the intermediate tensors added above.
        outputs = [x.name for x in ort_session.get_outputs()]
        ort_outs = ort_session.run(output_names=outputs, input_feed=ort_inputs)

        # Build a dict so each layer's output can be looked up by tensor name.
        result_outs = OrderedDict(zip(outputs, ort_outs))

        # print("result_outs.keys(): ", result_outs.keys())

        for key in result_outs.keys():
            if "aspp/conv_1x1_concat/Relu:0" in key:
                print("key: ", key)
                # print(f'key --> {key} : value --> {result_outs[key]}')
                print("value shape: ", result_outs[key].shape)
                stc_out = np.fromfile("./tmp/test_dump/stc_cpu/aspp.bin", dtype='float32').reshape(result_outs[key].shape)
                # verify_atol_rtol(stc_out, result_outs[key], atol=1e-6, rtol=5e-6)
                verify_atol_rtol(stc_out, result_outs[key], atol=1e-2, rtol=5e-2)
                # verify_atol_rtol(stc_out, result_outs[key], atol=0, rtol=0)
                break
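
verify_atol_rtol comes from the project-local utilities.utils module, which is not shown in this post. Assuming it just compares two arrays under absolute and relative tolerances, a stand-in based on numpy.isclose could look like the following (an assumed reimplementation, not the original helper):

import numpy as np

def verify_atol_rtol(expected, actual, atol=1e-6, rtol=1e-5):
    # Assumed behaviour: report how many elements differ beyond atol/rtol.
    expected = np.asarray(expected, dtype=np.float32)
    actual = np.asarray(actual, dtype=np.float32)
    close = np.isclose(actual, expected, atol=atol, rtol=rtol)
    mismatched = close.size - np.count_nonzero(close)
    print(f"mismatched elements: {mismatched}/{close.size}")
    print(f"max abs diff: {np.max(np.abs(actual - expected))}")
    return mismatched == 0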
