python BGE 模型转换为onnx给java调用
最近在做RAG,因为涉及embedding计算,用到了BAAI BGE小模型,但是模型是给python调用的,需要转换为onnx格式给java使用。所以有了下面的探索:
python代码:
# Script 1: export the transformer inside FlagModel (bge-base-zh-v1.5) to ONNX.
# Script 2: check that ONNX Runtime reproduces FlagModel.encode() and dump the
#           reference embeddings so they can be compared with the Java client.
import functools

import numpy as np
import onnxruntime as ort
import torch
from FlagEmbedding import FlagModel
from transformers import AutoModel, AutoTokenizer

# Sequence length used when tracing the export and when padding/truncating
# inputs below; the Java client pads to the same length.
MAX_LENGTH = 128

# Initialize the models. flag_model is the one that gets exported; flag_model2
# is an independent instance used as the reference encoder.
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
flag_model2 = FlagModel(model_name_or_path)
# Put the model to be exported in evaluation mode before tracing.
flag_model.model.eval()

# Sanity check of the reference encoder on a dummy input.
dummy_input_text = "This is a sample text for embedding calculation."
embedding = flag_model2.encode(dummy_input_text)
print("embedding shape:", embedding.shape)

inputs = flag_model.tokenizer(dummy_input_text, return_tensors="pt", padding='max_length',
                              truncation=True, max_length=MAX_LENGTH)
# Move the example inputs to the model's device for tracing.
inputs = {k: v.to(flag_model.device) for k, v in inputs.items()}

# Export to ONNX. Only the first output is named ('output'); it is what the
# verification below (ort_outs[0]) and the Java client (result.get(0)) read.
# Both batch and sequence dimensions are dynamic, so clients may pad to
# MAX_LENGTH (as before) or feed unpadded sequences.
onnx_model_path = "flag_model.onnx"
torch.onnx.export(
    flag_model.model,
    (inputs['input_ids'], inputs['attention_mask']),
    onnx_model_path,
    input_names=['input_ids', 'attention_mask'],
    output_names=['output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size', 1: 'sequence_length'},
    },
)
print(f"Model has been converted to ONNX and saved to {onnx_model_path}")

# Tokenizer for the ONNX path, loaded from the same model directory.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)


@functools.lru_cache(maxsize=None)
def _ort_session(path):
    """Return an ONNX Runtime session for *path*, created once and then cached."""
    # Creating an InferenceSession loads the model file; doing it once instead
    # of on every onnx_inference() call avoids repeating that work per text.
    return ort.InferenceSession(path)


def onnx_inference(text, pooling_method='cls', normalize_embeddings=True):
    """Embed *text* with the exported ONNX model, mirroring FlagModel.encode().

    Args:
        text: string (or list of strings) to embed; padded/truncated to
            MAX_LENGTH tokens.
        pooling_method: 'cls' takes the first token's hidden state; 'mean'
            averages hidden states over tokens whose attention mask is set.
        normalize_embeddings: when True, L2-normalize each embedding.

    Returns:
        np.ndarray with one embedding row per input text.

    Raises:
        ValueError: if pooling_method is neither 'cls' nor 'mean'.
    """
    ort_session = _ort_session(onnx_model_path)
    inputs = tokenizer(text, return_tensors="pt", padding='max_length',
                       truncation=True, max_length=MAX_LENGTH)
    # torch int64 tensors -> numpy int64 arrays, the dtype the exported graph was traced with.
    input_ids = inputs['input_ids'].cpu().numpy()
    attention_mask = inputs['attention_mask'].cpu().numpy()
    ort_outs = ort_session.run(None, {'input_ids': input_ids, 'attention_mask': attention_mask})
    last_hidden_state = ort_outs[0]  # indexed as (batch, token, hidden) below

    if pooling_method == 'cls':
        print("cls pooling method")
        embeddings = last_hidden_state[:, 0, :]
    elif pooling_method == 'mean':
        print("mean pooling mode")
        # Cast the mask to the hidden-state dtype so the product is not upcast to float64.
        mask = np.expand_dims(attention_mask, axis=-1).astype(last_hidden_state.dtype)
        embeddings = np.sum(last_hidden_state * mask, axis=1) / np.sum(mask, axis=1)
    else:
        raise ValueError(f"Unsupported pooling method: {pooling_method}")

    if normalize_embeddings:
        print("normalize embeddings")
        norm = np.linalg.norm(embeddings, axis=-1, keepdims=True)
        # Epsilon floor (same default as torch.nn.functional.normalize) avoids 0/0 -> NaN.
        embeddings = embeddings / np.maximum(norm, 1e-12)
    return embeddings


# Input texts (English and Chinese, varying lengths).
texts = ["This is a sample text for embedding calculation.",
         "Another example text to test the model.",
         "Yet another text to ensure consistency.",
         "Testing with different lengths and contents.",
         "Final text to verify the ONNX model accuracy.",
         "中文数据测试",
         "随便说点什么吧!反正也只是测试用!",
         "你好!jone!"]

# Compare the reference encoder with the ONNX path, keeping the reference
# vectors so they are computed only once.
reference_embeddings = []
for text in texts:
    reference = flag_model2.encode(text)
    reference_embeddings.append(reference)
    original_embedding = reference.reshape(1, -1)
    onnx_embedding = onnx_inference(text)
    print("shape compare:", original_embedding.shape, onnx_embedding.shape)
    max_difference = np.max(np.abs(original_embedding - onnx_embedding))
    print(f"Text: {text}")
    print(f"Max Difference: {max_difference}")
    print("-" * 50)

# Dump the reference embeddings, one bracketed comma-separated list per line,
# for comparison with the Java output.
with open("D:\\data\\python_bge.txt", "w", encoding="utf-8") as f:
    for reference in reference_embeddings:
        f.write("[" + ", ".join(map(str, reference)) + "]\n")
python输出:
D:\Python\Python312\python.exe D:\source\pythonProject\demo_onnx.py
embedding shape: (768,)
Model has been converted to ONNX and saved to flag_model.onnx
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: This is a sample text for embedding calculation.
Max Difference: 8.940696716308594e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Another example text to test the model.
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Yet another text to ensure consistency.
Max Difference: 2.384185791015625e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Testing with different lengths and contents.
Max Difference: 5.960464477539062e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Final text to verify the ONNX model accuracy.
Max Difference: 4.76837158203125e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 中文数据测试
Max Difference: 3.5762786865234375e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 随便说点什么吧!反正也只是测试用!
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 你好!jone!
Max Difference: 2.1047890186309814e-07
Java调用BGE onnx代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.onnxruntime. * ; import java.io.BufferedWriter; import java.io.FileWriter; import java.nio. file .Paths; import java.util. Map ; import java.util.HashMap; import java.nio.LongBuffer; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import java.io.IOException; public class App2 { private static String poolingMethod = "cls" ; private static boolean normalizeEmbeddings = true; private static long [] padArray( long [] array, int length) { long [] paddedArray = new long [length]; System.arraycopy(array, 0 , paddedArray, 0 , Math. 
min (array.length, length)); return paddedArray; } private static float [] encode(OrtEnvironment env, OrtSession session, HuggingFaceTokenizer tokenizer, String text) throws OrtException { Encoding enc = tokenizer.encode(text); long [] inputIdsData = enc.getIds(); long [] attentionMaskData = enc.getAttentionMask(); int maxLength = 128 ; int batchSize = 1 ; long [] inputIdsShape = new long []{batchSize, maxLength}; long [] attentionMaskShape = new long []{batchSize, maxLength}; / / 确保数组长度为 128 inputIdsData = padArray(inputIdsData, maxLength); attentionMaskData = padArray(attentionMaskData, maxLength); OnnxTensor inputIdsTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIdsData), inputIdsShape); OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMaskData), attentionMaskShape); / / 创建输入的 Map Map <String, OnnxTensor> inputs = new HashMap<>(); inputs.put( "input_ids" , inputIdsTensor); inputs.put( "attention_mask" , attentionMaskTensor); / / 运行推理 OrtSession.Result result = session.run(inputs); / / 获取输出 / / float [][][] output = ( float [][][]) result.get( 0 ).getValue(); / / System.out.println( "Output shape: [" + output.length + ", " + output[ 0 ].length + ", " + output[ 0 ][ 0 ].length + "]" ); / / 提取三维数组 float [][][] lastHiddenState = ( float [][][]) result.get( 0 ).getValue(); float [] embeddings; if ( "cls" .equals(poolingMethod)) { / / System.out.println( "cls pooling method" ); embeddings = lastHiddenState[ 0 ][ 0 ]; } else if ( "mean" .equals(poolingMethod)) { / / System.out.println( "mean pooling mode" ); int sequenceLength = lastHiddenState[ 0 ].length; int hiddenSize = lastHiddenState[ 0 ][ 0 ].length; float [] sum = new float [hiddenSize]; int count = 0 ; for ( int i = 0 ; i < sequenceLength; i + + ) { if (attentionMaskData[i] = = 1 ) { for ( int j = 0 ; j < hiddenSize; j + + ) { sum [j] + = lastHiddenState[ 0 ][i][j]; } count + + ; } } float [] mean = new float [hiddenSize]; for ( int j = 0 ; j < hiddenSize; j + 
+ ) { mean[j] = sum [j] / count; } embeddings = mean; } else { throw new IllegalArgumentException( "Unsupported pooling method: " + poolingMethod); } if (normalizeEmbeddings) { / / System.out.println( "normalize embeddings" ); float norm = 0 ; for ( float v : embeddings) { norm + = v * v; } norm = ( float ) Math.sqrt(norm); for ( int i = 0 ; i < embeddings.length; i + + ) { embeddings[i] / = norm; } } / / 释放资源 inputIdsTensor.close(); attentionMaskTensor.close(); return embeddings; } public static void main(String[] args) throws OrtException, IOException { / / 加载ONNX模型 String modelPath = "D:\\source\\pythonProject\\flag_model.onnx" ; / / "flag_model.onnx" ; OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get( "D:\\source\\pythonProject\\onnx\\tokenizer.json" )); String[] texts = { "This is a sample text for embedding calculation." , "Another example text to test the model." , "Yet another text to ensure consistency." , "Testing with different lengths and contents." , "Final text to verify the ONNX model accuracy." , "中文数据测试" , "随便说点什么吧!反正也只是测试用!" , "你好!jone!" }; try (BufferedWriter writer = new BufferedWriter(new FileWriter( "D:\\data\\java_bge.txt" ))) { for (String text : texts) { float [] clsEmbedding = encode(env, session, tokenizer, text); / / writer.write( "Text: " + text + "\n" ); / / writer.write( "CLS Embedding shape: " + clsEmbedding.length + "\n" ); writer.write( "[" ); for ( int i = 0 ; i < clsEmbedding.length; i + + ) { writer.write(String.valueOf(clsEmbedding[i])); if (i < clsEmbedding.length - 1 ) { writer.write( ", " ); } } writer.write( "]\n" ); } } session.close(); env.close(); } } |
最后,我们比较下二者的输出差异:
def compare_java_python_result(java_path=r"D:\data\java_bge.txt",
                               python_path=r"D:\data\python_bge.txt"):
    """Compare the embeddings written by the Java (ONNX Runtime) and Python (FlagModel) runs.

    Each file holds one embedding per line as a JSON array, in the same text
    order. Prints the maximum absolute element-wise difference for every pair.

    Args:
        java_path: file produced by the Java client.
        python_path: file produced by the Python reference script.

    Returns:
        list[float]: the maximum absolute difference for each embedding pair.

    Raises:
        ValueError: if the two files do not hold arrays of the same shape.
    """
    import json

    import numpy as np

    def load_embeddings(filename):
        """Parse one JSON array per non-blank line into a 2-D numpy array."""
        with open(filename, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line), which json.loads would reject.
            return np.array([json.loads(line) for line in f if line.strip()])

    java_embeddings = load_embeddings(java_path)
    python_embeddings = load_embeddings(python_path)

    # Same number of embeddings and same dimension are both required.
    if java_embeddings.shape != python_embeddings.shape:
        raise ValueError("The number of embeddings in the two files does not match.")

    max_differences = []
    for i, (java_embedding, python_embedding) in enumerate(
            zip(java_embeddings, python_embeddings), start=1):
        max_difference = np.max(np.abs(java_embedding - python_embedding))
        max_differences.append(float(max_difference))
        print(f"Embedding {i}:")
        print(f"Max Difference: {max_difference}")
        print("-" * 50)
    return max_differences


compare_java_python_result()
输出结果:
--------------------------------------------------
Embedding 1:
Max Difference: 7.499999999938112e-07
--------------------------------------------------
Embedding 2:
Max Difference: 5.600000000383076e-07
--------------------------------------------------
Embedding 3:
Max Difference: 1.8700000000565487e-07
--------------------------------------------------
Embedding 4:
Max Difference: 4.500000000406956e-07
--------------------------------------------------
Embedding 5:
Max Difference: 6.000000000172534e-07
--------------------------------------------------
Embedding 6:
Max Difference: 5.699999999775329e-07
--------------------------------------------------
Embedding 7:
Max Difference: 5.700000000885552e-07
--------------------------------------------------
Embedding 8:
Max Difference: 2.1039999999975662e-07
--------------------------------------------------
又是一个充满收获的一天!!!欧耶!!!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步