Converting a PyTorch model to ONNX and calling it from Java: an example
First, train a neural network with PyTorch in Python for simple MNIST digit prediction, and save the trained model in ONNX format.
The code is as follows:
import torch
import torch.nn as nn
import torch.optim as optim
import onnxruntime as ort
import numpy as np
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 64
learning_rate = 0.01
num_epochs = 1

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Define the neural network model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = SimpleNN()

def train():
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Train the model
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}')

    # Test the model
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Accuracy of the model on the test images: {100 * correct / total}%')

    # Inference example
    sample_data, sample_target = next(iter(test_loader))
    sample_output = model(sample_data)
    _, sample_predicted = torch.max(sample_output.data, 1)
    print(f'Predicted: {sample_predicted[:10]}')
    print(f'Actual: {sample_target[:10]}')

    # Save the model locally
    torch.save(model.state_dict(), 'simple_nn.pth')
    print('Model saved to simple_nn.pth')

    # Convert to ONNX format and save
    dummy_input = torch.randn(1, 1, 28, 28)  # Create a dummy input
    torch.onnx.export(model, dummy_input, 'simple_nn.onnx', input_names=['input'], output_names=['output'])
    print('Model converted to ONNX and saved to simple_nn.onnx')

def inference():
    # Data loading and preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    # Load the ONNX model
    onnx_model_path = 'simple_nn.onnx'
    ort_session = ort.InferenceSession(onnx_model_path)

    # Helper functions for inference
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def infer_onnx_model(ort_session, data):
        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}
        ort_outs = ort_session.run(None, ort_inputs)
        return ort_outs

    # Inference example
    sample_data, sample_target = next(iter(test_loader))
    sample_data = sample_data.view(1, 1, 28, 28)  # Adjust the input shape
    onnx_output = infer_onnx_model(ort_session, sample_data)
    onnx_predicted = np.argmax(onnx_output[0], axis=1)
    print(f'Predicted: {onnx_predicted[0]}')
    print(f'Actual: {sample_target.item()}')

if __name__ == '__main__':
    train()
    inference()
Code output:
Epoch 1/1, Batch 0, Loss: 2.3018319606781006
Epoch 1/1, Batch 100, Loss: 1.8857725858688354
Epoch 1/1, Batch 200, Loss: 1.0029046535491943
Epoch 1/1, Batch 300, Loss: 0.6656786203384399
Epoch 1/1, Batch 400, Loss: 0.641338586807251
Epoch 1/1, Batch 500, Loss: 0.5250198841094971
Epoch 1/1, Batch 600, Loss: 0.5605880618095398
Epoch 1/1, Batch 700, Loss: 0.5747233629226685
Epoch 1/1, Batch 800, Loss: 0.49430033564567566
Epoch 1/1, Batch 900, Loss: 0.28630945086479187
Accuracy of the model on the test images: 90.38%
Predicted: tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])
Actual: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
Model saved to simple_nn.pth
Model converted to ONNX and saved to simple_nn.onnx
Predicted: 7
Actual: 7
Next, we use the ONNX model in Java to make a prediction.
The code is as follows:
import ai.onnxruntime.*;

import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.Map;

public class App {
    public static void main(String[] args) throws OrtException {
        // Load the ONNX model
        String modelPath = "D:\\source\\pythonProject\\simple_nn.onnx"; // or just "simple_nn.onnx"
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(modelPath, opts);

        // Build a random input
        float[] inputData = new float[1 * 1 * 28 * 28];
        for (int i = 0; i < inputData.length; i++) {
            inputData[i] = (float) Math.random();
        }

        // Create the input tensor
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), new long[]{1, 1, 28, 28});

        // Run inference
        Map<String, OnnxTensor> inputs = Collections.singletonMap(session.getInputNames().iterator().next(), inputTensor);
        OrtSession.Result result = session.run(inputs);

        // Get the output
        float[][] output = (float[][]) result.get(0).getValue();
        int predictedLabel = argMax(output[0]);
        System.out.println("Predicted Label: " + predictedLabel);

        // Release resources
        inputTensor.close();
        session.close();
        env.close();
    }

    // Return the index of the maximum value
    private static int argMax(float[] array) {
        int maxIndex = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }
}
Run output:
Predicted Label: 8
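Note that the Java demo above feeds random noise into the network, so the printed label (8 here) is arbitrary. To reproduce the Python result on a real digit, the input must be preprocessed the same way as during training: a 28x28 grayscale image scaled to [0, 1] and then normalized with mean 0.1307 and std 0.3081. Below is a minimal sketch of that preprocessing; the class name MnistPreprocess and the image file are illustrative assumptions, not part of the original example:

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

public class MnistPreprocess {
    // Convert a 28x28 grayscale image into the float array the model expects,
    // applying the same Normalize((0.1307,), (0.3081,)) used in the Python transform.
    // Assumes an MNIST-style image: white digit on a black background.
    public static float[] toModelInput(String imagePath) throws IOException {
        BufferedImage img = ImageIO.read(new File(imagePath));
        if (img.getWidth() != 28 || img.getHeight() != 28) {
            throw new IllegalArgumentException("Expected a 28x28 image");
        }
        float[] input = new float[28 * 28];
        for (int y = 0; y < 28; y++) {
            for (int x = 0; x < 28; x++) {
                // Take the red channel as the gray value (0-255) and scale to [0, 1].
                int gray = (img.getRGB(x, y) >> 16) & 0xFF;
                float scaled = gray / 255.0f;
                // Apply the MNIST normalization.
                input[y * 28 + x] = (scaled - 0.1307f) / 0.3081f;
            }
        }
        return input;
    }
}

The returned array can then replace inputData in the App example before it is wrapped in a FloatBuffer.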
Appendix:
pom.xml file
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>test_onnx</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <name>test_onnx</name>
    <url>http://maven.apache.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <!-- ONNX Runtime dependency -->
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.15.1</version>
        </dependency>
    </dependencies>
</project>
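When wiring a new model into Java, it can also help to confirm the input/output names and shapes that were set by torch.onnx.export ("input" and "output" above, with shape [1, 1, 28, 28]). The sketch below uses the ONNX Runtime Java metadata API for that check; the class name InspectModel is illustrative, and the cast to TensorInfo assumes tensor-valued inputs/outputs as in this model:

import ai.onnxruntime.*;

import java.util.Arrays;
import java.util.Map;

public class InspectModel {
    public static void main(String[] args) throws OrtException {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession("simple_nn.onnx", new OrtSession.SessionOptions())) {
            // Print each input's name and tensor shape (dynamic dimensions appear as -1).
            for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
                TensorInfo info = (TensorInfo) entry.getValue().getInfo();
                System.out.println("Input  " + entry.getKey() + " shape: " + Arrays.toString(info.getShape()));
            }
            // Print each output's name and tensor shape.
            for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
                TensorInfo info = (TensorInfo) entry.getValue().getInfo();
                System.out.println("Output " + entry.getKey() + " shape: " + Arrays.toString(info.getShape()));
            }
        }
    }
}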
Tags: Machine Learning