PyTorch 模型转换为 ONNX 并在 Java 中调用的示例
首先用 PyTorch 训练一个简单的神经网络来进行 MNIST 手写数字预测,并将训练好的模型导出为 ONNX 格式:
代码如下:
"""Train a small MLP on MNIST, export it to ONNX, and sanity-check the export.

Running this module as a script trains the network, saves the weights
(simple_nn.pth), exports the model to ONNX (simple_nn.onnx), then reloads
the ONNX file with onnxruntime and predicts one test sample.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import onnxruntime as ort
import numpy as np
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 64
learning_rate = 0.01
num_epochs = 1

# Data loading and preprocessing (the Normalize constants are the
# conventional MNIST mean/std).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


class SimpleNN(nn.Module):
    """Three-layer fully connected network: 784 -> 128 -> 64 -> 10 logits."""

    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images to (N, 784) before the linear stack.
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = SimpleNN()


def train():
    """Train `model`, report test accuracy, then save .pth and export .onnx."""
    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Training loop.
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}')

    # Evaluate on the held-out test set.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Accuracy of the model on the test images: {100 * correct / total}%')

    # Quick inference sanity check on one test batch.
    sample_data, sample_target = next(iter(test_loader))
    sample_output = model(sample_data)
    _, sample_predicted = torch.max(sample_output.data, 1)
    print(f'Predicted: {sample_predicted[:10]}')
    print(f'Actual: {sample_target[:10]}')

    # Save the weights locally.
    torch.save(model.state_dict(), 'simple_nn.pth')
    print('Model saved to simple_nn.pth')

    # Export to ONNX; the dummy input fixes the exported input shape
    # to (1, 1, 28, 28).
    dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(model, dummy_input, 'simple_nn.onnx', input_names=['input'], output_names=['output'])
    print('Model converted to ONNX and saved to simple_nn.onnx')


def inference():
    """Load simple_nn.onnx with onnxruntime and predict one MNIST test sample."""
    # Rebuild the test pipeline with batch_size=1 for single-sample inference.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    # Load the exported ONNX model.
    onnx_model_path = 'simple_nn.onnx'
    ort_session = ort.InferenceSession(onnx_model_path)

    def to_numpy(tensor):
        # ONNX Runtime consumes numpy arrays, not torch tensors.
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def infer_onnx_model(ort_session, data):
        # Feed the tensor under the session's (single) input name.
        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}
        ort_outs = ort_session.run(None, ort_inputs)
        return ort_outs

    # Inference example.
    sample_data, sample_target = next(iter(test_loader))
    sample_data = sample_data.view(1, 1, 28, 28)  # match the exported input shape
    onnx_output = infer_onnx_model(ort_session, sample_data)
    onnx_predicted = np.argmax(onnx_output[0], axis=1)
    print(f'Predicted: {onnx_predicted[0]}')
    print(f'Actual: {sample_target.item()}')


if __name__ == '__main__':
    train()
    inference()
代码输出:
Epoch 1/1, Batch 0, Loss: 2.3018319606781006 Epoch 1/1, Batch 100, Loss: 1.8857725858688354 Epoch 1/1, Batch 200, Loss: 1.0029046535491943 Epoch 1/1, Batch 300, Loss: 0.6656786203384399 Epoch 1/1, Batch 400, Loss: 0.641338586807251 Epoch 1/1, Batch 500, Loss: 0.5250198841094971 Epoch 1/1, Batch 600, Loss: 0.5605880618095398 Epoch 1/1, Batch 700, Loss: 0.5747233629226685 Epoch 1/1, Batch 800, Loss: 0.49430033564567566 Epoch 1/1, Batch 900, Loss: 0.28630945086479187 Accuracy of the model on the test images: 90.38% Predicted: tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9]) Actual: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9]) Model saved to simple_nn.pth Model converted to ONNX and saved to simple_nn.onnx Predicted: 7 Actual: 7
接下来,我们在 Java 中加载该 ONNX 模型并进行预测:
代码如下:
import ai.onnxruntime.*; import java.nio.FloatBuffer; import java.util.Collections; import java.util.Map; public class App { public static void main(String[] args) throws OrtException { // 加载ONNX模型 String modelPath = "D:\\source\\pythonProject\\simple_nn.onnx"; // "simple_nn.onnx"; OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); // 构造随机输入 float[] inputData = new float[1 * 1 * 28 * 28]; for (int i = 0; i < inputData.length; i++) { inputData[i] = (float) Math.random(); } // 创建输入张量 OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), new long[]{1, 1, 28, 28}); // 运行推理 Map<String, OnnxTensor> inputs = Collections.singletonMap(session.getInputNames().iterator().next(), inputTensor); OrtSession.Result result = session.run(inputs); // 获取输出 float[][] output = (float[][]) result.get(0).getValue(); int predictedLabel = argMax(output[0]); System.out.println("Predicted Label: " + predictedLabel); // 释放资源 inputTensor.close(); session.close(); env.close(); } // 获取最大值的索引 private static int argMax(float[] array) { int maxIndex = 0; for (int i = 1; i < array.length; i++) { if (array[i] > array[maxIndex]) { maxIndex = i; } } return maxIndex; } }
运行输出:
Predicted Label: 8
附:
pom.xml文件
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>test_onnx</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <name>test_onnx</name>
    <url>http://maven.apache.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <!-- ONNX Runtime dependency -->
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.15.1</version>
        </dependency>
    </dependencies>
</project>