Python:将 BGE 模型转换为 ONNX 格式供 Java 调用

最近在做 RAG,需要计算 embedding,用到了 BAAI 的 BGE 小模型。但该模型原本是给 Python 调用的,需要先转换为 ONNX 格式才能在 Java 中使用,于是有了下面的探索:

python代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from transformers import AutoTokenizer, AutoModel
from FlagEmbedding import FlagModel
 
# Initialise the BGE model through FlagEmbedding.
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
# Second, independent instance used only as the reference encoder for the
# sanity check below.
flag_model2 = FlagModel(model_name_or_path)

# Put the wrapped torch module in evaluation mode before tracing it.
flag_model.model.eval()

# Sanity-check the reference encoder on a dummy sentence.
dummy_input_text = "This is a sample text for embedding calculation."
embedding = flag_model2.encode(dummy_input_text)
print("embedding shape:", embedding.shape)

# Example input used for tracing. It is padded to 128 tokens (what the Java
# client sends), but the sequence axis is exported as dynamic below, so the ONNX
# model is not locked to that length.
inputs = flag_model.tokenizer(dummy_input_text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)

# Move the inputs to the model's device.
inputs = {k: v.to(flag_model.device) for k, v in inputs.items()}

# Export to ONNX. The first graph output (named 'output') is the model's first
# return value, which the inference code pools (CLS / mean) itself.
onnx_model_path = "flag_model.onnx"
with torch.no_grad():  # tracing needs no autograd bookkeeping
    torch.onnx.export(
        flag_model.model,
        (inputs['input_ids'], inputs['attention_mask']),
        onnx_model_path,
        input_names=['input_ids', 'attention_mask'],
        output_names=['output'],
        # Both batch size and sequence length are left symbolic.
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'output': {0: 'batch_size', 1: 'sequence_length'},
        },
    )

print(f"Model has been converted to ONNX and saved to {onnx_model_path}")
 
import torch
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from FlagEmbedding import FlagModel
 
# Load, from the same local checkpoint directory, the reference FlagEmbedding
# model and the Hugging Face tokenizer that prepares inputs for the ONNX model.
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model, tokenizer = (
    FlagModel(model_name_or_path),
    AutoTokenizer.from_pretrained(model_name_or_path),
)
# Switch the wrapped torch module to evaluation (inference) mode.
flag_model.model.eval()
 
# ONNX Runtime sessions keyed by model path. Creating an InferenceSession loads
# and initialises the whole model, so it is done once per path rather than once
# per encoded text.
_ORT_SESSION_CACHE = {}


def _get_ort_session(onnx_model_path):
    """Return a cached ``ort.InferenceSession`` for *onnx_model_path*, creating it on first use."""
    session = _ORT_SESSION_CACHE.get(onnx_model_path)
    if session is None:
        session = ort.InferenceSession(onnx_model_path)
        _ORT_SESSION_CACHE[onnx_model_path] = session
    return session


def onnx_inference(text, pooling_method='cls', normalize_embeddings=True,
                   onnx_model_path="flag_model.onnx", max_length=128):
    """Embed *text* with the exported ONNX model.

    Args:
        text: Input passed straight to the module-level ``tokenizer``.
        pooling_method: ``'cls'`` takes the first token's hidden state;
            ``'mean'`` averages hidden states over unmasked tokens.
        normalize_embeddings: If true, L2-normalise each pooled vector.
        onnx_model_path: Path of the ONNX file to run (session is cached).
        max_length: Pad/truncate length used when tokenizing.

    Returns:
        A 2-D numpy array with one embedding row per tokenized sequence.

    Raises:
        ValueError: If *pooling_method* is not ``'cls'`` or ``'mean'``.
    """
    # Validate up front so a typo fails fast, before any inference work.
    if pooling_method not in ('cls', 'mean'):
        raise ValueError(f"Unsupported pooling method: {pooling_method}")

    ort_session = _get_ort_session(onnx_model_path)
    inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
    # torch -> numpy keeps the int64 dtype the exported graph expects.
    input_ids = inputs['input_ids'].cpu().numpy()
    attention_mask = inputs['attention_mask'].cpu().numpy()
    ort_outs = ort_session.run(None, {'input_ids': input_ids, 'attention_mask': attention_mask})
    last_hidden_state = ort_outs[0]

    # Apply pooling
    if pooling_method == 'cls':
        print("cls pooling method")
        embeddings = last_hidden_state[:, 0, :]
    else:
        print("mean pooling mode")
        # Zero out padded positions, then divide by the number of real tokens.
        s = np.sum(last_hidden_state * np.expand_dims(attention_mask, axis=-1), axis=1)
        d = np.sum(attention_mask, axis=1, keepdims=True)
        embeddings = s / d

    # Normalize embeddings if required
    if normalize_embeddings:
        print("normalize embeddings")
        embeddings = embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)

    return embeddings
 
 
# Input texts: English and Chinese samples of varying length.
texts = [
    "This is a sample text for embedding calculation.",
    "Another example text to test the model.",
    "Yet another text to ensure consistency.",
    "Testing with different lengths and contents.",
    "Final text to verify the ONNX model accuracy.",
    "中文数据测试",
    "随便说点什么吧!反正也只是测试用!",
    "你好!jone!"
]

# Compare the FlagModel reference embedding with the ONNX Runtime embedding for
# each text. `flag_model` is the reference loaded at the top of this script;
# `flag_model2` only exists when the export code runs in the same process.
reference_embeddings = []
for text in texts:
    reference_embedding = flag_model.encode(text)
    reference_embeddings.append(reference_embedding)  # reused for the dump below
    original_embedding = reference_embedding.reshape(1, -1)
    onnx_embedding = onnx_inference(text)
    print("shape compare:", original_embedding.shape, onnx_embedding.shape)
    max_difference = np.max(np.abs(original_embedding - onnx_embedding))
    print(f"Text: {text}")
    print(f"Max Difference: {max_difference}")
    print("-" * 50)

# Write each reference embedding as one JSON-style array per line, so the file
# can be compared against the Java client's output.
with open("D:\\data\\python_bge.txt", "w", encoding="utf-8") as f:
    for cls_embedding in reference_embeddings:
        f.write("[" + ", ".join(map(str, cls_embedding)) + "]\n")

  

Python 输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
D:\Python\Python312\python.exe D:\source\pythonProject\demo_onnx.py
embedding shape: (768,)
Model has been converted to ONNX and saved to flag_model.onnx
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: This is a sample text for embedding calculation.
Max Difference: 8.940696716308594e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Another example text to test the model.
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Yet another text to ensure consistency.
Max Difference: 2.384185791015625e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Testing with different lengths and contents.
Max Difference: 5.960464477539062e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Final text to verify the ONNX model accuracy.
Max Difference: 4.76837158203125e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 中文数据测试
Max Difference: 3.5762786865234375e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 随便说点什么吧!反正也只是测试用!
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 你好!jone!
Max Difference: 2.1047890186309814e-07

  

 

Java 调用 BGE ONNX 模型的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.*;
 
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.nio.file.Paths;
import java.util.Map;
import java.util.HashMap;
import java.nio.LongBuffer;
 
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
 
import java.io.IOException;
 
 
public class App2 {
    private static String poolingMethod = "cls";
    private static boolean normalizeEmbeddings = true;
 
    private static long[] padArray(long[] array, int length) {
        long[] paddedArray = new long[length];
        System.arraycopy(array, 0, paddedArray, 0, Math.min(array.length, length));
        return paddedArray;
    }
 
    private static float[] encode(OrtEnvironment env, OrtSession session, HuggingFaceTokenizer tokenizer, String text) throws OrtException {
        Encoding enc = tokenizer.encode(text);
        long[] inputIdsData = enc.getIds();
        long[] attentionMaskData = enc.getAttentionMask();
 
 
        int maxLength = 128;
        int batchSize = 1;
 
        long[] inputIdsShape = new long[]{batchSize, maxLength};
        long[] attentionMaskShape = new long[]{batchSize, maxLength};
        // 确保数组长度为128
        inputIdsData = padArray(inputIdsData, maxLength);
        attentionMaskData = padArray(attentionMaskData, maxLength);
 
        OnnxTensor inputIdsTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIdsData), inputIdsShape);
        OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMaskData), attentionMaskShape);
 
        // 创建输入的Map
        Map<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put("input_ids", inputIdsTensor);
        inputs.put("attention_mask", attentionMaskTensor);
 
        // 运行推理
        OrtSession.Result result = session.run(inputs);
 
        // 获取输出
//        float[][][] output = (float[][][]) result.get(0).getValue();
//        System.out.println("Output shape: [" + output.length + ", " + output[0].length + ", " + output[0][0].length + "]");
 
        // 提取三维数组
        float[][][] lastHiddenState = (float[][][]) result.get(0).getValue();
 
        float[] embeddings;
        if ("cls".equals(poolingMethod)) {
//            System.out.println("cls pooling method");
            embeddings = lastHiddenState[0][0];
        } else if ("mean".equals(poolingMethod)) {
//            System.out.println("mean pooling mode");
            int sequenceLength = lastHiddenState[0].length;
            int hiddenSize = lastHiddenState[0][0].length;
            float[] sum = new float[hiddenSize];
            int count = 0;
 
            for (int i = 0; i < sequenceLength; i++) {
                if (attentionMaskData[i] == 1) {
                    for (int j = 0; j < hiddenSize; j++) {
                        sum[j] += lastHiddenState[0][i][j];
                    }
                    count++;
                }
            }
 
            float[] mean = new float[hiddenSize];
            for (int j = 0; j < hiddenSize; j++) {
                mean[j] = sum[j] / count;
            }
            embeddings = mean;
        } else {
            throw new IllegalArgumentException("Unsupported pooling method: " + poolingMethod);
        }
 
        if (normalizeEmbeddings) {
//            System.out.println("normalize embeddings");
            float norm = 0;
            for (float v : embeddings) {
                norm += v * v;
            }
            norm = (float) Math.sqrt(norm);
            for (int i = 0; i < embeddings.length; i++) {
                embeddings[i] /= norm;
            }
        }
 
        // 释放资源
        inputIdsTensor.close();
        attentionMaskTensor.close();
 
        return embeddings;
    }
 
 
    public static void main(String[] args) throws OrtException, IOException {
        // 加载ONNX模型
        String modelPath = "D:\\source\\pythonProject\\flag_model.onnx";// "flag_model.onnx";
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(modelPath, opts);
 
        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get("D:\\source\\pythonProject\\onnx\\tokenizer.json"));
 
        String[] texts = {
                "This is a sample text for embedding calculation.",
                "Another example text to test the model.",
                "Yet another text to ensure consistency.",
                "Testing with different lengths and contents.",
                "Final text to verify the ONNX model accuracy.",
                "中文数据测试",
                "随便说点什么吧!反正也只是测试用!",
                "你好!jone!"
        };
 
        try (BufferedWriter writer = new BufferedWriter(new FileWriter("D:\\data\\java_bge.txt"))) {
            for (String text : texts) {
                float[] clsEmbedding = encode(env, session, tokenizer, text);
 
                // writer.write("Text: " + text + "\n");
//                writer.write("CLS Embedding shape: " + clsEmbedding.length + "\n");
                writer.write("[");
                for (int i = 0; i < clsEmbedding.length; i++) {
                    writer.write(String.valueOf(clsEmbedding[i]));
                    if (i < clsEmbedding.length - 1) {
                        writer.write(", ");
                    }
                }
                writer.write("]\n");
            }
        }
 
        session.close();
        env.close();
    }
}

  

 

最后,比较一下 Java 与 Python 两者输出的 embedding 差异:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def compare_java_python_result(java_path=r"D:\data\java_bge.txt",
                               python_path=r"D:\data\python_bge.txt"):
    """Compare the embeddings written by the Java client with those written by Python.

    Each file holds one embedding per line, formatted as a JSON array. Blank
    lines are ignored.

    Args:
        java_path: File produced by the Java ONNX client.
        python_path: File produced by the Python reference encoder.

    Returns:
        The maximum absolute element-wise difference for each embedding pair, in
        file order. The differences are also printed.

    Raises:
        ValueError: If the two files yield arrays of different shape, i.e. a
            different number of embeddings or a different dimension.
    """
    import numpy as np
    import json

    def load_embeddings(filename):
        # Skip blank lines (e.g. a stray trailing newline) instead of failing in json.loads.
        with open(filename, 'r', encoding='utf-8') as f:
            return np.array([json.loads(line) for line in f if line.strip()])

    java_embeddings = load_embeddings(java_path)
    python_embeddings = load_embeddings(python_path)

    if java_embeddings.shape != python_embeddings.shape:
        raise ValueError(
            f"Embedding arrays differ in shape: java {java_embeddings.shape} "
            f"vs python {python_embeddings.shape}"
        )

    max_differences = []
    pairs = zip(java_embeddings, python_embeddings)
    for i, (java_embedding, python_embedding) in enumerate(pairs, start=1):
        max_difference = float(np.max(np.abs(java_embedding - python_embedding)))
        max_differences.append(max_difference)
        print(f"Embedding {i}:")
        print(f"Max Difference: {max_difference}")
        print("-" * 50)
    return max_differences


if __name__ == "__main__":
    compare_java_python_result()

  

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
--------------------------------------------------
Embedding 1:
Max Difference: 7.499999999938112e-07
--------------------------------------------------
Embedding 2:
Max Difference: 5.600000000383076e-07
--------------------------------------------------
Embedding 3:
Max Difference: 1.8700000000565487e-07
--------------------------------------------------
Embedding 4:
Max Difference: 4.500000000406956e-07
--------------------------------------------------
Embedding 5:
Max Difference: 6.000000000172534e-07
--------------------------------------------------
Embedding 6:
Max Difference: 5.699999999775329e-07
--------------------------------------------------
Embedding 7:
Max Difference: 5.700000000885552e-07
--------------------------------------------------
Embedding 8:
Max Difference: 2.1039999999975662e-07
--------------------------------------------------

  

又是一个充满收获的一天!!!欧耶!!!

posted @   bonelee  阅读(572)  评论(1编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2021-06-29 用于对机器学习模型进行对抗性攻击、防御和基准测试的Python库:CleverHans 3.0.0
2021-06-29 对抗机器学习(Adversarial Machine Learning)发展现状 2018年文章
2018-06-29 DDOS工具合集---CC 2.0(僵尸网络proxy,单一url,可设置cookie,refer),传奇克星(代理+单一url,可设置cookie),NetBot_Attacker网络僵尸1.0(僵尸网络,HTTP NO-Cache Get攻击模式,CC攻击,HTTP空GET请求攻击),傀儡僵尸VIP1.4版(僵尸网络,动态单一url)、上兴网络僵尸2.3、中国制造网络僵尸、安全基地网络僵尸==
2017-06-29 DNS报文格式(RFC1035)
点击右上角即可分享
微信分享提示