pytorch构建并保存模型(.pth) 转化为 torchscript(.pt), 导出为onnx格式
pytorch(.pth)模型转化为 torchscript(.pt), 导出为onnx格式
1 .pth模型转换为.pt模型
import torch import torchvision from models import fcn model=torchvision.models.vgg16() state_dict = torch.load("./checkpoint-epoch100.pth") #print(state_dict) model.load_state_dict(state_dict,False) model.eval() x = torch.rand(1,3,128,128) ts = torch.jit.trace(model, x) ts.save('fcn_vgg16.net')
注意很多人在转换的时候报错是因为:model.load_state_dict(state_dict)后面没用False参数
2. .pth模型转化为.onnx模型
如需使用opencv来加载模型,则需将.pth转化为.onnx格式的模型。
a.先安装onnx,使用命令:pip install onnx
;
b.使用以下命令转为.onnx模型
import io import torch import torch.onnx import torchvision from models import fcn device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def test(): model=torchvision.models.vgg16() pthfile = r'./checkpoint-epoch100.pth' loaded_model = torch.load(pthfile, map_location='cpu') # try: # loaded_model.eval() # except AttributeError as error: # print(error) #model.load_state_dict(loaded_model['state_dict']) # model = model.to(device) #data type nchw dummy_input1 = torch.randn(1, 3, 244, 244) # dummy_input2 = torch.randn(1, 3, 64, 64) # dummy_input3 = torch.randn(1, 3, 64, 64) input_names = [ "actual_input_1"] output_names = [ "output1" ] # torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names) torch.onnx.export(model, dummy_input1, "fcn.onnx", verbose=True, input_names=input_names, output_names=output_names) if __name__ == "__main__": test()
====================================================
import torch import torch.nn as nn from torch.autograd import Function import onnx import torch.onnx class TinyNet(nn.Module): def __init__(self): super(TinyNet, self).__init__() self.abs = torch.abs def forward(self, x): x = self.abs(x) return x model = TinyNet() input = torch.FloatTensor([[-1, -2, 3],[-4, -5, 6]]) input_names = ["input_0"] output_names = ["output0"] torch.onnx.export(model, (input,), 'tinynet.onnx', opset_version=19, verbose=True, input_names=input_names, output_names=output_names) print(onnx.load('tinynet.onnx'))
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY