pytorch onnx模型转换java调用示例

pytorch的模型可以转为hf、onnx、trt,都是什么格式?几种转换有啥区别?

PyTorch 模型可以转换为多个不同的格式以适应不同的应用场景和优化需求。以下是一些常见的转换格式及其特点:

1. PyTorch 模型(.pt 或 .pth

  • 格式:原生 PyTorch 模型保存为 .pt 或 .pth 文件。
  • 用途:用于加载和推理使用 PyTorch 库模型。
  • 优缺点:
    • 优点:原生支持所有 PyTorch 功能,易于在 PyTorch 环境中调试和运行。
    • 缺点:依赖 PyTorch 运行时,可能不适用于资源有限的设备(如嵌入式系统)。

2. Hugging Face Transformers(.bin 或 .h5

  • 格式:转换后的模型保存为 .bin(PyTorch)或 .h5(TensorFlow),这些文件格式用于 Hugging Face Transformers 库。
  • 用途:主要用于 NLP(自然语言处理)任务,适用于使用 Hugging Face Transformers 库的环境。
  • 优缺点:
    • 优点:利用 Hugging Face 的丰富的预训练模型和强大的 API。
    • 缺点:可能需要适应 Transformer's API 和数据处理规范。

3. ONNX(Open Neural Network Exchange,.onnx

  • 格式:转换后的 ONNX 保存为 .onnx 文件。
  • 用途:ONNX 格式广泛用于模型交换和跨平台部署,支持多个深度学习框架(如 PyTorch、TensorFlow、Caffe2)。
  • 优缺点:
    • 优点:跨平台兼容性,可以在不同的深度学习框架和硬件上部署(如手机、Web、嵌入式设备)。
    • 缺点:可能需要一些模型转换和 API 接口的适配,某些高级功能可能无法完全支持。

4. TensorRT(TensorRT Engine,.trt 或 .plan

  • 格式:转换后的 TensorRT 引擎保存为 .trt 或 .plan 文件。
  • 用途:用于在 NVIDIA GPU 上进行高性能推理。
  • 优缺点:
    • 优点:显著提高推理速度和效率,尤其是在 NVIDIA 硬件(如 GPU)上。
    • 缺点:仅适用于 NVIDIA 硬件,可能涉及较复杂的优化和硬件特性依赖。

总结

不同的格式有各自的用途和优缺点,选择适合特定应用的格式非常重要。PyTorch 原生格式适用于调试和开发,Hugging Face Transformers 格式适用于 NLP 应用,ONNX 格式适用于跨平台模型部署,而 TensorRT 格式则特别适用于需要高性能推理的 NVIDIA 硬件。选择适合的格式通常取决于目标设备、性能需求以及开发环境。

 

示例

python torch训练一个神经网络,用来进行简单的mnist数字预测!并将训练后的模型存为onnx格式:

代码如下:

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
import torch.nn as nn
import torch.optim as optim
import onnxruntime as ort
import numpy as np
import torch
from torchvision import datasets, transforms
 
# 定义超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 1
 
# 数据加载和预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
 
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
 
 
# 定义神经网络模型
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
 
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
 
model = SimpleNN()
 
def train():
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
 
    # 训练模型
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
 
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}')
 
    # 测试模型
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
 
    print(f'Accuracy of the model on the test images: {100 * correct / total}%')
 
    # 推理示例
    sample_data, sample_target = next(iter(test_loader))
    sample_output = model(sample_data)
    _, sample_predicted = torch.max(sample_output.data, 1)
 
    print(f'Predicted: {sample_predicted[:10]}')
    print(f'Actual: {sample_target[:10]}')
 
 
    # 保存模型到本地
    torch.save(model.state_dict(), 'simple_nn.pth')
    print('Model saved to simple_nn.pth')
 
    # 转换为ONNX格式并保存
    dummy_input = torch.randn(1, 1, 28, 28# 创建一个dummy输入
    torch.onnx.export(model, dummy_input, 'simple_nn.onnx', input_names=['input'], output_names=['output'])
    print('Model converted to ONNX and saved to simple_nn.onnx')
 
 
def inference():
    # 数据加载和预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
 
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
 
    # 加载ONNX模型
    onnx_model_path = 'simple_nn.onnx'
    ort_session = ort.InferenceSession(onnx_model_path)
 
    # 定义一个函数来进行推理
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
 
    def infer_onnx_model(ort_session, data):
        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}
        ort_outs = ort_session.run(None, ort_inputs)
        return ort_outs
 
    # 推理示例
    sample_data, sample_target = next(iter(test_loader))
    sample_data = sample_data.view(1, 1, 28, 28# 调整输入形状
    onnx_output = infer_onnx_model(ort_session, sample_data)
    onnx_predicted = np.argmax(onnx_output[0], axis=1)
 
    print(f'Predicted: {onnx_predicted[0]}')
    print(f'Actual: {sample_target.item()}')
 
 
if __name__ == '__main__':
    train()
    inference()

  

代码输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
Epoch 1/1, Batch 0, Loss: 2.3018319606781006
Epoch 1/1, Batch 100, Loss: 1.8857725858688354
Epoch 1/1, Batch 200, Loss: 1.0029046535491943
Epoch 1/1, Batch 300, Loss: 0.6656786203384399
Epoch 1/1, Batch 400, Loss: 0.641338586807251
Epoch 1/1, Batch 500, Loss: 0.5250198841094971
Epoch 1/1, Batch 600, Loss: 0.5605880618095398
Epoch 1/1, Batch 700, Loss: 0.5747233629226685
Epoch 1/1, Batch 800, Loss: 0.49430033564567566
Epoch 1/1, Batch 900, Loss: 0.28630945086479187
Accuracy of the model on the test images: 90.38%
Predicted: tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])
Actual: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
Model saved to simple_nn.pth
Model converted to ONNX and saved to simple_nn.onnx
Predicted: 7
Actual: 7

  

接下来,我们在java中使用该onnx模型进行预测:

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.Map;
 
public class App {
 
    public static void main(String[] args) throws OrtException {
        // 加载ONNX模型
        String modelPath = "D:\\source\\pythonProject\\simple_nn.onnx"; // "simple_nn.onnx";
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(modelPath, opts);
 
        // 构造随机输入
        float[] inputData = new float[1 * 1 * 28 * 28];
        for (int i = 0; i < inputData.length; i++) {
            inputData[i] = (float) Math.random();
        }
 
        // 创建输入张量
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), new long[]{1, 1, 28, 28});
 
        // 运行推理
        Map<String, OnnxTensor> inputs = Collections.singletonMap(session.getInputNames().iterator().next(), inputTensor);
        OrtSession.Result result = session.run(inputs);
 
        // 获取输出
        float[][] output = (float[][]) result.get(0).getValue();
        int predictedLabel = argMax(output[0]);
 
        System.out.println("Predicted Label: " + predictedLabel);
 
        // 释放资源
        inputTensor.close();
        session.close();
        env.close();
    }
 
    // 获取最大值的索引
    private static int argMax(float[] array) {
        int maxIndex = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }
}

  

运行输出:

1
Predicted Label: 8

  

附:

pom.xml文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
 
  <groupId>org.example</groupId>
  <artifactId>test_onnx</artifactId>
  <version>1.0-SNAPSHOT</version>
  <packaging>jar</packaging>
 
  <name>test_onnx</name>
  <url>http://maven.apache.org</url>
 
  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>
 
  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>
 
    <!-- ONNX Runtime dependency -->
    <dependency>
      <groupId>com.microsoft.onnxruntime</groupId>
      <artifactId>onnxruntime</artifactId>
      <version>1.15.1</version>
    </dependency>
  </dependencies>
</project>

  

 

posted @   bonelee  阅读(425)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2022-06-27 Linux恶意挖矿软件SkidMap分析——伪造CPU占用
2022-06-27 CPU/GPU挖矿占用率统计
2022-06-27 CUDA ---- Hello World From GPU
2021-06-27 sqlmap的使用 ----常用tamper模块,TODO,绕过WAF的测试
2021-06-27 绕过WAF的扫描——模拟爬虫
2017-06-27 DNS隧道工具使用 不过其网络传输速度限制较大
2017-06-27 DNS tunnel的原理及实战
点击右上角即可分享
微信分享提示