python BGE 模型转换为onnx给java调用

最近在做RAG,因为涉及embedding计算,用到了BAAI BGE小模型,但是模型是给python调用的,需要转换为onnx格式给java使用。所以有了下面的探索:

python代码(下面两段代码位于同一个脚本 demo_onnx.py 中,第二段会复用第一段创建的 flag_model2):

import torch
from transformers import AutoTokenizer, AutoModel
from FlagEmbedding import FlagModel

# Load the model twice: `flag_model` is the instance whose underlying
# network is exported, `flag_model2` is used below (and later in this
# script) to produce reference embeddings through FlagModel.encode().
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
flag_model2 = FlagModel(model_name_or_path)

# Put the network to be exported into evaluation mode.
flag_model.model.eval()

# Build a dummy input; it only drives the export trace.
dummy_input_text = "This is a sample text for embedding calculation."
embedding = flag_model2.encode(dummy_input_text)
print("embedding shape:", embedding.shape)

inputs = flag_model.tokenizer(dummy_input_text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)

# Move the inputs to the same device as the model.
inputs = {k: v.to(flag_model.device) for k, v in inputs.items()}

# Export the model to ONNX.
# Both the batch axis and the sequence axis are declared dynamic so the
# exported graph is not locked to the 128-token dummy input; callers that
# still pad every input to 128 tokens keep working unchanged.
onnx_model_path = "flag_model.onnx"
torch.onnx.export(
    flag_model.model,
    (inputs['input_ids'], inputs['attention_mask']),
    onnx_model_path,
    input_names=['input_ids', 'attention_mask'],
    output_names=['output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size', 1: 'sequence_length'},
    },
)

print(f"Model has been converted to ONNX and saved to {onnx_model_path}")

import torch
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from FlagEmbedding import FlagModel

# Initialise the model and the tokenizer from the same local checkpoint directory.
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Put the model into evaluation mode.
flag_model.model.eval()

# ONNX模型推理
# Cache of ONNX Runtime sessions keyed by model path, so the model file is
# loaded once instead of on every onnx_inference() call.
_ORT_SESSIONS = {}


def _get_ort_session(onnx_path):
    """Return the cached InferenceSession for *onnx_path*, creating it on first use."""
    session = _ORT_SESSIONS.get(onnx_path)
    if session is None:
        session = ort.InferenceSession(onnx_path)
        _ORT_SESSIONS[onnx_path] = session
    return session


def onnx_inference(text, pooling_method='cls', normalize_embeddings=True, onnx_path="flag_model.onnx"):
    """Embed *text* with the exported ONNX model.

    The text is tokenized with the module-level ``tokenizer`` (padded and
    truncated to 128 tokens), run through the ONNX model, pooled and
    optionally L2-normalised.

    Args:
        text: A string (or list of strings) accepted by the tokenizer.
        pooling_method: ``'cls'`` takes the vector at sequence position 0;
            ``'mean'`` averages the positions whose attention mask is set.
        normalize_embeddings: If true, scale each embedding to unit L2 norm.
        onnx_path: Path of the ONNX model file. The session for a given path
            is created once and reused, so re-exporting the file within the
            same process is not picked up by later calls.

    Returns:
        A numpy array with one embedding row per input text.

    Raises:
        ValueError: If *pooling_method* is neither ``'cls'`` nor ``'mean'``.
    """
    # Fail early with a clear message; previously an unknown method reached
    # the normalisation step with ``embeddings`` unbound.
    if pooling_method not in ('cls', 'mean'):
        raise ValueError(f"Unsupported pooling method: {pooling_method}")

    ort_session = _get_ort_session(onnx_path)
    inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)
    input_ids = inputs['input_ids'].cpu().numpy()
    attention_mask = inputs['attention_mask'].cpu().numpy()
    ort_inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
    ort_outs = ort_session.run(None, ort_inputs)
    # First model output, indexed below as (batch, sequence, hidden).
    last_hidden_state = ort_outs[0]

    # Apply pooling
    if pooling_method == 'cls':
        print("cls pooling method")
        embeddings = last_hidden_state[:, 0, :]
    else:
        print("mean pooling mode")
        # Cast the integer mask to the activation dtype so the pooled result
        # is not silently promoted to float64.
        mask = attention_mask.astype(last_hidden_state.dtype)
        s = np.sum(last_hidden_state * np.expand_dims(mask, axis=-1), axis=1)
        d = np.sum(mask, axis=1, keepdims=True)
        embeddings = s / d

    # Normalize embeddings if required
    if normalize_embeddings:
        print("normalize embeddings")
        norm = np.linalg.norm(embeddings, axis=-1, keepdims=True)
        embeddings = embeddings / norm

    return embeddings


# Input texts used for every comparison in this script.
texts = [
    "This is a sample text for embedding calculation.",
    "Another example text to test the model.",
    "Yet another text to ensure consistency.",
    "Testing with different lengths and contents.",
    "Final text to verify the ONNX model accuracy.",
    "中文数据测试",
    "随便说点什么吧!反正也只是测试用!",
    "你好!jone!"
]
# Compare the FlagModel.encode() embedding with the ONNX embedding, text by text.
for text in texts:
    # reshape(1, -1) turns the 1-D reference vector into a single row without
    # hard-coding the hidden size, so other model sizes work as well.
    original_embedding = flag_model2.encode(text).reshape(1, -1)
    onnx_embedding = onnx_inference(text)
    print("shape compare:", original_embedding.shape, onnx_embedding.shape)
    difference = np.abs(original_embedding - onnx_embedding)
    max_difference = np.max(difference)
    print(f"Text: {text}")
    print(f"Max Difference: {max_difference}")
    print("-" * 50)


# Dump the reference embedding of every text, one bracketed,
# comma-separated list per line, for the later Java-vs-Python comparison.
with open("D:\\data\\python_bge.txt", "w", encoding="utf-8") as f:
    for text in texts:
        cls_embedding = flag_model2.encode(text)
        components = ", ".join(str(value) for value in cls_embedding)
        f.write("[" + components + "]\n")

  

python输出:

D:\Python\Python312\python.exe D:\source\pythonProject\demo_onnx.py 
embedding shape: (768,)
Model has been converted to ONNX and saved to flag_model.onnx
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: This is a sample text for embedding calculation.
Max Difference: 8.940696716308594e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Another example text to test the model.
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Yet another text to ensure consistency.
Max Difference: 2.384185791015625e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Testing with different lengths and contents.
Max Difference: 5.960464477539062e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Final text to verify the ONNX model accuracy.
Max Difference: 4.76837158203125e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 中文数据测试
Max Difference: 3.5762786865234375e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 随便说点什么吧!反正也只是测试用!
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 你好!jone!
Max Difference: 2.1047890186309814e-07

  

 

Java调用BGE onnx代码:

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.*;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.nio.file.Paths;
import java.util.Map;
import java.util.HashMap;
import java.nio.LongBuffer;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.io.IOException;


/**
 * Runs the exported BGE ONNX model from Java and writes one embedding per
 * input text to a file, one bracketed comma-separated list per line.
 */
public class App2 {
    /** Pooling applied to the model output: "cls" or "mean". */
    private static String poolingMethod = "cls";
    /** Whether the pooled vector is scaled to unit L2 norm. */
    private static boolean normalizeEmbeddings = true;

    /**
     * Returns a copy of {@code array} with exactly {@code length} elements:
     * shorter inputs are filled up with zeros, longer inputs are cut off.
     *
     * NOTE(review): cutting off simply drops the trailing elements; confirm
     * that this matches the truncation done on the Python side for long texts.
     */
    private static long[] padArray(long[] array, int length) {
        long[] paddedArray = new long[length];
        System.arraycopy(array, 0, paddedArray, 0, Math.min(array.length, length));
        return paddedArray;
    }

    /**
     * Averages the rows of {@code hiddenStates} whose attention-mask entry is 1.
     *
     * @param hiddenStates  per-position vectors of one sequence
     * @param attentionMask mask with at least as many entries as {@code hiddenStates} has rows
     * @return the element-wise mean over the selected rows
     */
    private static float[] meanPool(float[][] hiddenStates, long[] attentionMask) {
        int sequenceLength = hiddenStates.length;
        int hiddenSize = hiddenStates[0].length;
        float[] sum = new float[hiddenSize];
        int count = 0;

        for (int i = 0; i < sequenceLength; i++) {
            if (attentionMask[i] == 1) {
                for (int j = 0; j < hiddenSize; j++) {
                    sum[j] += hiddenStates[i][j];
                }
                count++;
            }
        }

        float[] mean = new float[hiddenSize];
        for (int j = 0; j < hiddenSize; j++) {
            mean[j] = sum[j] / count;
        }
        return mean;
    }

    /** Scales {@code vector} in place to unit L2 norm; an all-zero vector is left unchanged. */
    private static void l2Normalize(float[] vector) {
        float norm = 0;
        for (float v : vector) {
            norm += v * v;
        }
        norm = (float) Math.sqrt(norm);
        // Guard against dividing by zero, which would fill the vector with NaN.
        if (norm > 0) {
            for (int i = 0; i < vector.length; i++) {
                vector[i] /= norm;
            }
        }
    }

    /**
     * Computes the embedding of {@code text}.
     *
     * The token ids and attention mask are padded/cut to 128 entries, fed to
     * the model as "input_ids" and "attention_mask", and the first model
     * output is pooled according to {@link #poolingMethod} and optionally
     * normalised.
     *
     * @throws OrtException             if tensor creation or inference fails
     * @throws IllegalArgumentException if {@link #poolingMethod} is not "cls" or "mean"
     */
    private static float[] encode(OrtEnvironment env, OrtSession session, HuggingFaceTokenizer tokenizer, String text) throws OrtException {
        Encoding enc = tokenizer.encode(text);

        int maxLength = 128;
        int batchSize = 1;

        long[] inputIdsShape = new long[]{batchSize, maxLength};
        long[] attentionMaskShape = new long[]{batchSize, maxLength};
        // Make sure both arrays have exactly maxLength entries.
        long[] inputIdsData = padArray(enc.getIds(), maxLength);
        long[] attentionMaskData = padArray(enc.getAttentionMask(), maxLength);

        float[][][] lastHiddenState;
        // try-with-resources releases the native tensors and the inference
        // result even when inference throws; the result was previously never
        // closed at all.
        try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIdsData), inputIdsShape);
             OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMaskData), attentionMaskShape)) {

            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", inputIdsTensor);
            inputs.put("attention_mask", attentionMaskTensor);

            try (OrtSession.Result result = session.run(inputs)) {
                // Read the first output as a 3-D float array while the result is still open.
                lastHiddenState = (float[][][]) result.get(0).getValue();
            }
        }

        float[] embeddings;
        if ("cls".equals(poolingMethod)) {
            embeddings = lastHiddenState[0][0];
        } else if ("mean".equals(poolingMethod)) {
            embeddings = meanPool(lastHiddenState[0], attentionMaskData);
        } else {
            throw new IllegalArgumentException("Unsupported pooling method: " + poolingMethod);
        }

        if (normalizeEmbeddings) {
            l2Normalize(embeddings);
        }

        return embeddings;
    }


    /**
     * Loads the ONNX model and the tokenizer, embeds a fixed list of texts and
     * writes the embeddings to D:\data\java_bge.txt.
     */
    public static void main(String[] args) throws OrtException, IOException {
        // Path of the exported ONNX model.
        String modelPath = "D:\\source\\pythonProject\\flag_model.onnx";

        String[] texts = {
                "This is a sample text for embedding calculation.",
                "Another example text to test the model.",
                "Yet another text to ensure consistency.",
                "Testing with different lengths and contents.",
                "Final text to verify the ONNX model accuracy.",
                "中文数据测试",
                "随便说点什么吧!反正也只是测试用!",
                "你好!jone!"
        };

        // All native resources are declared in one try-with-resources so they
        // are released in reverse order even if inference or file I/O fails.
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
             OrtSession session = env.createSession(modelPath, opts);
             HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get("D:\\source\\pythonProject\\onnx\\tokenizer.json"));
             BufferedWriter writer = new BufferedWriter(new FileWriter("D:\\data\\java_bge.txt"))) {
            for (String text : texts) {
                float[] clsEmbedding = encode(env, session, tokenizer, text);

                writer.write("[");
                for (int i = 0; i < clsEmbedding.length; i++) {
                    writer.write(String.valueOf(clsEmbedding[i]));
                    if (i < clsEmbedding.length - 1) {
                        writer.write(", ");
                    }
                }
                writer.write("]\n");
            }
        }
    }
}

  

 

最后,我们比较下二者的输出差异:

def compare_java_python_result(java_path=r"D:\data\java_bge.txt", python_path=r"D:\data\python_bge.txt"):
    """Compare the embeddings written by the Java and the Python program.

    Each file holds one JSON array of numbers per line; blank lines are
    ignored. For every pair of rows the maximum absolute element-wise
    difference is printed and collected.

    Args:
        java_path: File written by the Java program.
        python_path: File written by the Python program.

    Returns:
        A list with the maximum absolute difference of each row pair, in
        file order.

    Raises:
        ValueError: If the two files do not yield arrays of the same shape.
    """
    import numpy as np
    import json

    def load_embeddings(filename):
        """Read one JSON array per non-blank line into a 2-D numpy array."""
        with open(filename, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads fail on an empty string.
            return np.array([json.loads(line) for line in f if line.strip()])

    java_embeddings = load_embeddings(java_path)
    python_embeddings = load_embeddings(python_path)

    # Both row count and vector length must agree for a row-wise comparison.
    if java_embeddings.shape != python_embeddings.shape:
        raise ValueError(
            "The embeddings in the two files do not have the same shape: "
            f"{java_embeddings.shape} vs {python_embeddings.shape}."
        )

    max_differences = []
    for i, (java_embedding, python_embedding) in enumerate(zip(java_embeddings, python_embeddings), start=1):
        max_difference = np.max(np.abs(java_embedding - python_embedding))

        print(f"Embedding {i}:")
        print(f"Max Difference: {max_difference}")
        print("-" * 50)
        max_differences.append(float(max_difference))

    return max_differences


# Run the Java-vs-Python comparison on the two embedding dump files.
compare_java_python_result()

  

输出结果:

--------------------------------------------------
Embedding 1:
Max Difference: 7.499999999938112e-07
--------------------------------------------------
Embedding 2:
Max Difference: 5.600000000383076e-07
--------------------------------------------------
Embedding 3:
Max Difference: 1.8700000000565487e-07
--------------------------------------------------
Embedding 4:
Max Difference: 4.500000000406956e-07
--------------------------------------------------
Embedding 5:
Max Difference: 6.000000000172534e-07
--------------------------------------------------
Embedding 6:
Max Difference: 5.699999999775329e-07
--------------------------------------------------
Embedding 7:
Max Difference: 5.700000000885552e-07
--------------------------------------------------
Embedding 8:
Max Difference: 2.1039999999975662e-07
--------------------------------------------------

  

又是充满收获的一天!!!欧耶!!!

posted @ 2024-06-29 11:39  bonelee  阅读(380)  评论(1)  编辑  收藏  举报