Python将模型参数文件(.pth/.pkl等)转换为ONNX格式
import torch import pickle import numpy as np model_path = r'./sVGG16.pkl' # 模型参数路径 dummy_input = torch.randn(1, 3, 256, 256) # 先随机一个模型输入的数据 model = sVGG16() # 定义模型结构,此处是我自己设计的模型 checkpoing = torch.load(model_path, 'cpu') # 导入模型参数 model.load_state_dict(checkpoing) # 将模型参数赋予自定义的模型 torch.onnx.export(model, dummy_input, "model_best.onnx",verbose=True) # 将模型保存成.onnx格式