pytorch onnx模型转换java调用示例

pytorch的模型可以转为hf、onnx、trt,都是什么格式?几种转换有啥区别?

PyTorch 模型可以转换为多个不同的格式以适应不同的应用场景和优化需求。以下是一些常见的转换格式及其特点:

1. PyTorch 模型(.pt 或 .pth

  • 格式:原生 PyTorch 模型保存为 .pt 或 .pth 文件。
  • 用途:用于加载和推理使用 PyTorch 库模型。
  • 优缺点:
    • 优点:原生支持所有 PyTorch 功能,易于在 PyTorch 环境中调试和运行。
    • 缺点:依赖 PyTorch 运行时,可能不适用于资源有限的设备(如嵌入式系统)。

2. Hugging Face Transformers(.bin 或 .h5

  • 格式:转换后的模型保存为 .bin(PyTorch)或 .h5(TensorFlow),这些文件格式用于 Hugging Face Transformers 库。
  • 用途:主要用于 NLP(自然语言处理)任务,适用于使用 Hugging Face Transformers 库的环境。
  • 优缺点:
    • 优点:利用 Hugging Face 的丰富的预训练模型和强大的 API。
    • 缺点:可能需要适应 Transformer's API 和数据处理规范。

3. ONNX(Open Neural Network Exchange,.onnx

  • 格式:转换后的 ONNX 保存为 .onnx 文件。
  • 用途:ONNX 格式广泛用于模型交换和跨平台部署,支持多个深度学习框架(如 PyTorch、TensorFlow、Caffe2)。
  • 优缺点:
    • 优点:跨平台兼容性,可以在不同的深度学习框架和硬件上部署(如手机、Web、嵌入式设备)。
    • 缺点:可能需要一些模型转换和 API 接口的适配,某些高级功能可能无法完全支持。

4. TensorRT(TensorRT Engine,.trt 或 .plan

  • 格式:转换后的 TensorRT 引擎保存为 .trt 或 .plan 文件。
  • 用途:用于在 NVIDIA GPU 上进行高性能推理。
  • 优缺点:
    • 优点:显著提高推理速度和效率,尤其是在 NVIDIA 硬件(如 GPU)上。
    • 缺点:仅适用于 NVIDIA 硬件,可能涉及较复杂的优化和硬件特性依赖。

总结

不同的格式有各自的用途和优缺点,选择适合特定应用的格式非常重要。PyTorch 原生格式适用于调试和开发,Hugging Face Transformers 格式适用于 NLP 应用,ONNX 格式适用于跨平台模型部署,而 TensorRT 格式则特别适用于需要高性能推理的 NVIDIA 硬件。选择适合的格式通常取决于目标设备、性能需求以及开发环境。

 

示例

python torch训练一个神经网络,用来进行简单的mnist数字预测!并将训练后的模型存为onnx格式:

代码如下:

 

import torch
import torch.nn as nn
import torch.optim as optim
import onnxruntime as ort
import numpy as np
import torch
from torchvision import datasets, transforms

# 定义超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 1

# 数据加载和预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


# 定义神经网络模型
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = SimpleNN()

def train():
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # 训练模型
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item()}')

    # 测试模型
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Accuracy of the model on the test images: {100 * correct / total}%')

    # 推理示例
    sample_data, sample_target = next(iter(test_loader))
    sample_output = model(sample_data)
    _, sample_predicted = torch.max(sample_output.data, 1)

    print(f'Predicted: {sample_predicted[:10]}')
    print(f'Actual: {sample_target[:10]}')


    # 保存模型到本地
    torch.save(model.state_dict(), 'simple_nn.pth')
    print('Model saved to simple_nn.pth')

    # 转换为ONNX格式并保存
    dummy_input = torch.randn(1, 1, 28, 28)  # 创建一个dummy输入
    torch.onnx.export(model, dummy_input, 'simple_nn.onnx', input_names=['input'], output_names=['output'])
    print('Model converted to ONNX and saved to simple_nn.onnx')


def inference():
    # 数据加载和预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    # 加载ONNX模型
    onnx_model_path = 'simple_nn.onnx'
    ort_session = ort.InferenceSession(onnx_model_path)

    # 定义一个函数来进行推理
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def infer_onnx_model(ort_session, data):
        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(data)}
        ort_outs = ort_session.run(None, ort_inputs)
        return ort_outs

    # 推理示例
    sample_data, sample_target = next(iter(test_loader))
    sample_data = sample_data.view(1, 1, 28, 28)  # 调整输入形状
    onnx_output = infer_onnx_model(ort_session, sample_data)
    onnx_predicted = np.argmax(onnx_output[0], axis=1)

    print(f'Predicted: {onnx_predicted[0]}')
    print(f'Actual: {sample_target.item()}')


if __name__ == '__main__':
    train()
    inference()

  

代码输出:

Epoch 1/1, Batch 0, Loss: 2.3018319606781006
Epoch 1/1, Batch 100, Loss: 1.8857725858688354
Epoch 1/1, Batch 200, Loss: 1.0029046535491943
Epoch 1/1, Batch 300, Loss: 0.6656786203384399
Epoch 1/1, Batch 400, Loss: 0.641338586807251
Epoch 1/1, Batch 500, Loss: 0.5250198841094971
Epoch 1/1, Batch 600, Loss: 0.5605880618095398
Epoch 1/1, Batch 700, Loss: 0.5747233629226685
Epoch 1/1, Batch 800, Loss: 0.49430033564567566
Epoch 1/1, Batch 900, Loss: 0.28630945086479187
Accuracy of the model on the test images: 90.38%
Predicted: tensor([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])
Actual: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
Model saved to simple_nn.pth
Model converted to ONNX and saved to simple_nn.onnx
Predicted: 7
Actual: 7

  

接下来,我们在java中使用该onnx模型进行预测:

代码如下:

import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.Map;

public class App {

    public static void main(String[] args) throws OrtException {
        // 加载ONNX模型
        String modelPath = "D:\\source\\pythonProject\\simple_nn.onnx"; // "simple_nn.onnx";
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(modelPath, opts);

        // 构造随机输入
        float[] inputData = new float[1 * 1 * 28 * 28];
        for (int i = 0; i < inputData.length; i++) {
            inputData[i] = (float) Math.random();
        }

        // 创建输入张量
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), new long[]{1, 1, 28, 28});

        // 运行推理
        Map<String, OnnxTensor> inputs = Collections.singletonMap(session.getInputNames().iterator().next(), inputTensor);
        OrtSession.Result result = session.run(inputs);

        // 获取输出
        float[][] output = (float[][]) result.get(0).getValue();
        int predictedLabel = argMax(output[0]);

        System.out.println("Predicted Label: " + predictedLabel);

        // 释放资源
        inputTensor.close();
        session.close();
        env.close();
    }

    // 获取最大值的索引
    private static int argMax(float[] array) {
        int maxIndex = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }
}

  

运行输出:

Predicted Label: 8

  

附:

pom.xml文件

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <groupId>org.example</groupId>
  <artifactId>test_onnx</artifactId>
  <version>1.0-SNAPSHOT</version>
  <packaging>jar</packaging>

  <name>test_onnx</name>
  <url>http://maven.apache.org</url>

  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>

  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>

    <!-- ONNX Runtime dependency -->
    <dependency>
      <groupId>com.microsoft.onnxruntime</groupId>
      <artifactId>onnxruntime</artifactId>
      <version>1.15.1</version>
    </dependency>
  </dependencies>
</project>

  

 

posted @ 2024-06-27 17:26  bonelee  阅读(36)  评论(0编辑  收藏  举报