Python将模型参数文件(.pth/.pkl等)转换为ONNX格式

import torch
import pickle
import numpy as np

model_path = r'./sVGG16.pkl'                           # 模型参数路径  
dummy_input = torch.randn(1, 3, 256, 256)  # 先随机一个模型输入的数据

model = sVGG16()                          # 定义模型结构,此处是我自己设计的模型
checkpoing = torch.load(model_path, 'cpu')  # 导入模型参数
model.load_state_dict(checkpoing)           # 将模型参数赋予自定义的模型

torch.onnx.export(model, dummy_input, "model_best.onnx",verbose=True) # 将模型保存成.onnx格式

  

posted @ 2022-04-25 15:26  小丑_jk  阅读(1280)  评论(0编辑  收藏  举报