:)torch转onnx总结--|
torch->onnx
参考:
参考连接:https://blog.csdn.net/cxx654/article/details/123011332
1 安装 onnx
>python -m pip install onnx
>python -m pip install onnxruntime
调用onnx
ModuleNotFoundError: No module named 'onnx.defs'
2 原理
I onnx是通用规则的部署格式
所以oneflow ,pytorch等等都需要转onnx。
II 调用export接口转onnx
onnxruntime执行
3 代码工程
导入模块
import os import numpy import torch from torch import nn import torch.optim as opt import onnxruntime as ort import onnx
3.1 定义NN
1)def forward中的输入x 训练时输入多少维度就是多少维度。
比如for 循环内x的维度为 [batchsize=10, channel =3, height=256, weight=256]
2)def forward中的输入x 不需要考虑循环batchsize , 不用取单次batchsize=1.
3)输出的return [batchsize 还是10, ...,...,...]
4)CONV与linear之间 有个拍平的过程,
self.flat = nn.Flatten(start_dim=1)
代码
class SC(nn.Module): """ input shape = [n, 3, 8, 8] conv-> """ def __init__(self): super(SC, self).__init__() self.cv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, bias=True) self.flat = nn.Flatten(start_dim=1) # 2分类 self.fc1 = nn.Linear(360, 2) def forward(self, x): # 1 输入batch原来一样 print("entry========", x.shape) # 2 实际计算时会batch=1计算结果 # 3 返回原来batch,各个结果。 x_ = self.cv1(x) x_ = self.flat(x_) x_ = self.fc1(x_) return x_
5)定义NN 与 训练教程 与训练过程是分开的
比如以上代码为定义NN
训练教程和训练过程 不定义在类中
# 训练 配置 opter = opt.Adam(sc.parameters(), lr=0.03) loss = nn.CrossEntropyLoss() # for 训练过程
3.2 保存为输入尺寸固定的onnx
优点:量化时候比较容易
注意:1) torch.load(model) model的类一定要导入或就在同文件。 和torch中的pickel策略有关。
3.3 保存为动态尺寸的onnx
带有实验 给固定参数的onnx ,输入batchsize=另外的数值, 会显示错误
batchsize = 100 x = torch.randn(size=(batchsize, 3, 8, 8)) model_out = sc(x) print(model_out) # 什么时候 with torch.no_grad()??? os.makedirs("zandir", exist_ok=True) def gen_static(): torch.onnx.export(sc, x, "zandir\sc.onnx", opset_version=11, input_names=["inp"], output_names=["opt"]) # 生成动态尺寸 def gen_dynamic(): dynamic_axes = {"inp": {0: "batchsize"}, "opt": {0: "batchsize"}} torch.onnx.export(sc, x, "zandir\sc_dyn.onnx", opset_version=11, dynamic_axes=dynamic_axes, input_names=["inp"], output_names=["opt"]) # 测试导出静态模型传入非100 batchsize 是否报错 def load_static(): model = onnx.load("zandir\sc.onnx") r = onnx.checker.check_model(model) print("r===", r) def load_and_run_onnx_static(): model = onnx.load("zandir\sc.onnx") session = ort.InferenceSession("zandir\sc.onnx") # x = numpy.random.randn(100,3,8,8).astype(numpy.float32) x = numpy.random.randn(16,3,8,8).astype(numpy.float32) inputdata = {"inp": x} output = session.run(None, inputdata) print(len(output)) print(output[0]) def load_and_run_onnx_dyn(): session = ort.InferenceSession("zandir\sc_dyn.onnx") # x = numpy.random.randn(100,3,8,8).astype(numpy.float32) x = numpy.random.randn(16,3,8,8).astype(numpy.float32) inputdata = {"inp": x} output = session.run(None, inputdata) print(len(output)) print(output[0]) gen_static() gen_dynamic() load_static() # load_and_run_onnx_static() load_and_run_onnx_dyn()