Converting a BGE model to ONNX in Python for use from Java
I have been working on RAG recently. Because it involves embedding computation, I used the small BAAI BGE model, but the model is meant to be called from Python, so it needs to be converted to ONNX format for use from Java. Hence the exploration below.
Python code:
import torch
from transformers import AutoTokenizer, AutoModel
from FlagEmbedding import FlagModel

# Initialize the model
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
flag_model2 = FlagModel(model_name_or_path)

# Put the model in evaluation mode
flag_model.model.eval()

# Create a dummy input
dummy_input_text = "This is a sample text for embedding calculation."
embedding = flag_model2.encode(dummy_input_text)
print("embedding shape:", embedding.shape)

inputs = flag_model.tokenizer(dummy_input_text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)

# Move the inputs to the model's device
inputs = {k: v.to(flag_model.device) for k, v in inputs.items()}

# Export the model to ONNX format
onnx_model_path = "flag_model.onnx"
torch.onnx.export(
    flag_model.model,
    (inputs['input_ids'], inputs['attention_mask']),
    onnx_model_path,
    input_names=['input_ids', 'attention_mask'],
    output_names=['output'],
    dynamic_axes={'input_ids': {0: 'batch_size'}, 'attention_mask': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
print(f"Model has been converted to ONNX and saved to {onnx_model_path}")

import torch
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from FlagEmbedding import FlagModel

# Initialize the model and tokenizer
model_name_or_path = 'bge-base-zh-v1.5/'
flag_model = FlagModel(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Put the model in evaluation mode
flag_model.model.eval()

# ONNX model inference
def onnx_inference(text, pooling_method='cls', normalize_embeddings=True):
    ort_session = ort.InferenceSession("flag_model.onnx")
    inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=128)
    input_ids = inputs['input_ids'].cpu().numpy()
    attention_mask = inputs['attention_mask'].cpu().numpy()
    ort_inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
    ort_outs = ort_session.run(None, ort_inputs)
    last_hidden_state = ort_outs[0]

    # Apply pooling
    if pooling_method == 'cls':
        print("cls pooling method")
        embeddings = last_hidden_state[:, 0, :]
    elif pooling_method == 'mean':
        print("mean pooling mode")
        s = np.sum(last_hidden_state * np.expand_dims(attention_mask, axis=-1), axis=1)
        d = np.sum(attention_mask, axis=1, keepdims=True)
        embeddings = s / d

    # Normalize embeddings if required
    if normalize_embeddings:
        print("normalize embeddings")
        norm = np.linalg.norm(embeddings, axis=-1, keepdims=True)
        embeddings = embeddings / norm

    return embeddings

# Input texts
texts = [
    "This is a sample text for embedding calculation.",
    "Another example text to test the model.",
    "Yet another text to ensure consistency.",
    "Testing with different lengths and contents.",
    "Final text to verify the ONNX model accuracy.",
    "中文数据测试",
    "随便说点什么吧!反正也只是测试用!",
    "你好!jone!"
]

# Compare the results
for text in texts:
    # original_embedding = original_inference(text)
    original_embedding = flag_model2.encode(text).reshape(1, 768)
    onnx_embedding = onnx_inference(text)  # .reshape(768,)
    print("shape compare:", original_embedding.shape, onnx_embedding.shape)
    difference = np.abs(original_embedding - onnx_embedding)
    max_difference = np.max(difference)
    print(f"Text: {text}")
    print(f"Max Difference: {max_difference}")
    print("-" * 50)

with open("D:\\data\\python_bge.txt", "w", encoding="utf-8") as f:
    for text in texts:
        cls_embedding = flag_model2.encode(text)
        f.write("[")
        f.write(", ".join(map(str, cls_embedding)))
        f.write("]\n")
Python output:
D:\Python\Python312\python.exe D:\source\pythonProject\demo_onnx.py
embedding shape: (768,)
Model has been converted to ONNX and saved to flag_model.onnx
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: This is a sample text for embedding calculation.
Max Difference: 8.940696716308594e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Another example text to test the model.
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Yet another text to ensure consistency.
Max Difference: 2.384185791015625e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Testing with different lengths and contents.
Max Difference: 5.960464477539062e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: Final text to verify the ONNX model accuracy.
Max Difference: 4.76837158203125e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 中文数据测试
Max Difference: 3.5762786865234375e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 随便说点什么吧!反正也只是测试用!
Max Difference: 5.364418029785156e-07
--------------------------------------------------
cls pooling method
normalize embeddings
shape compare: (1, 768) (1, 768)
Text: 你好!jone!
Max Difference: 2.1047890186309814e-07
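One thing the export above does not do: only dimension 0 (the batch dimension) is declared dynamic, so the exported graph expects a fixed sequence length of 128, which is why both the ONNX inference above and the Java code below pad every input to max_length=128. If that padding is unwanted, a variant export can also declare the sequence dimension dynamic. The sketch below is only an illustration of that option and is not part of the original workflow; the file name flag_model_dynamic.onnx is hypothetical and not used elsewhere in this post.

# Variant export (a sketch, not part of the original post): also mark the
# sequence-length dimension as dynamic so callers need not pad to 128 tokens.
import torch
from FlagEmbedding import FlagModel

flag_model = FlagModel('bge-base-zh-v1.5/')
flag_model.model.eval()

inputs = flag_model.tokenizer("This is a sample text for embedding calculation.",
                              return_tensors="pt", padding='max_length',
                              truncation=True, max_length=128)
inputs = {k: v.to(flag_model.device) for k, v in inputs.items()}

torch.onnx.export(
    flag_model.model,
    (inputs['input_ids'], inputs['attention_mask']),
    "flag_model_dynamic.onnx",  # hypothetical output file name
    input_names=['input_ids', 'attention_mask'],
    output_names=['output'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size', 1: 'sequence_length'},
    },
)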
Java code that calls the BGE ONNX model:
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.*;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.nio.file.Paths;
import java.util.Map;
import java.util.HashMap;
import java.nio.LongBuffer;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.IOException;

public class App2 {

    private static String poolingMethod = "cls";
    private static boolean normalizeEmbeddings = true;

    private static long[] padArray(long[] array, int length) {
        long[] paddedArray = new long[length];
        System.arraycopy(array, 0, paddedArray, 0, Math.min(array.length, length));
        return paddedArray;
    }

    private static float[] encode(OrtEnvironment env, OrtSession session, HuggingFaceTokenizer tokenizer, String text) throws OrtException {
        Encoding enc = tokenizer.encode(text);
        long[] inputIdsData = enc.getIds();
        long[] attentionMaskData = enc.getAttentionMask();

        int maxLength = 128;
        int batchSize = 1;
        long[] inputIdsShape = new long[]{batchSize, maxLength};
        long[] attentionMaskShape = new long[]{batchSize, maxLength};

        // Make sure the arrays have length 128
        inputIdsData = padArray(inputIdsData, maxLength);
        attentionMaskData = padArray(attentionMaskData, maxLength);

        OnnxTensor inputIdsTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIdsData), inputIdsShape);
        OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMaskData), attentionMaskShape);

        // Build the input map
        Map<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put("input_ids", inputIdsTensor);
        inputs.put("attention_mask", attentionMaskTensor);

        // Run inference
        OrtSession.Result result = session.run(inputs);

        // Get the output
        // float[][][] output = (float[][][]) result.get(0).getValue();
        // System.out.println("Output shape: [" + output.length + ", " + output[0].length + ", " + output[0][0].length + "]");

        // Extract the 3-D array
        float[][][] lastHiddenState = (float[][][]) result.get(0).getValue();

        float[] embeddings;
        if ("cls".equals(poolingMethod)) {
            // System.out.println("cls pooling method");
            embeddings = lastHiddenState[0][0];
        } else if ("mean".equals(poolingMethod)) {
            // System.out.println("mean pooling mode");
            int sequenceLength = lastHiddenState[0].length;
            int hiddenSize = lastHiddenState[0][0].length;
            float[] sum = new float[hiddenSize];
            int count = 0;
            for (int i = 0; i < sequenceLength; i++) {
                if (attentionMaskData[i] == 1) {
                    for (int j = 0; j < hiddenSize; j++) {
                        sum[j] += lastHiddenState[0][i][j];
                    }
                    count++;
                }
            }
            float[] mean = new float[hiddenSize];
            for (int j = 0; j < hiddenSize; j++) {
                mean[j] = sum[j] / count;
            }
            embeddings = mean;
        } else {
            throw new IllegalArgumentException("Unsupported pooling method: " + poolingMethod);
        }

        if (normalizeEmbeddings) {
            // System.out.println("normalize embeddings");
            float norm = 0;
            for (float v : embeddings) {
                norm += v * v;
            }
            norm = (float) Math.sqrt(norm);
            for (int i = 0; i < embeddings.length; i++) {
                embeddings[i] /= norm;
            }
        }

        // Release the tensors
        inputIdsTensor.close();
        attentionMaskTensor.close();

        return embeddings;
    }

    public static void main(String[] args) throws OrtException, IOException {
        // Load the ONNX model
        String modelPath = "D:\\source\\pythonProject\\flag_model.onnx"; // "flag_model.onnx";
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(modelPath, opts);
        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get("D:\\source\\pythonProject\\onnx\\tokenizer.json"));

        String[] texts = {
                "This is a sample text for embedding calculation.",
                "Another example text to test the model.",
                "Yet another text to ensure consistency.",
                "Testing with different lengths and contents.",
                "Final text to verify the ONNX model accuracy.",
                "中文数据测试",
                "随便说点什么吧!反正也只是测试用!",
                "你好!jone!"
        };

        try (BufferedWriter writer = new BufferedWriter(new FileWriter("D:\\data\\java_bge.txt"))) {
            for (String text : texts) {
                float[] clsEmbedding = encode(env, session, tokenizer, text);
                // writer.write("Text: " + text + "\n");
                // writer.write("CLS Embedding shape: " + clsEmbedding.length + "\n");
                writer.write("[");
                for (int i = 0; i < clsEmbedding.length; i++) {
                    writer.write(String.valueOf(clsEmbedding[i]));
                    if (i < clsEmbedding.length - 1) {
                        writer.write(", ");
                    }
                }
                writer.write("]\n");
            }
        }

        session.close();
        env.close();
    }
}
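The Java code loads a tokenizer.json from D:\source\pythonProject\onnx\, but the Python script earlier never shows how that file is produced. A minimal sketch of one way to generate it is below, assuming the bge-base-zh-v1.5/ directory contains the usual Hugging Face fast-tokenizer files; the "onnx/" output directory is only an assumption chosen to match the path used in the Java code.

# Minimal sketch (not part of the original scripts): write out the tokenizer files,
# including tokenizer.json, that the Java side loads via HuggingFaceTokenizer.newInstance().
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bge-base-zh-v1.5/')
tokenizer.save_pretrained("onnx/")  # writes tokenizer.json when a fast tokenizer is available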
Finally, let's compare the outputs of the two:
def compare_java_python_result():
    import numpy as np
    import json

    # Helper: read embedding vectors from a file
    def load_embeddings(filename):
        embeddings = []
        with open(filename, 'r') as f:
            for line in f:
                embedding = json.loads(line.strip())
                embeddings.append(embedding)
        return np.array(embeddings)

    # Load the embeddings from the two files
    original_embeddings = load_embeddings(r"D:\data\java_bge.txt")
    onnx_embeddings = load_embeddings(r"D:\data\python_bge.txt")

    # Check that the two files contain the same number of embeddings
    if original_embeddings.shape != onnx_embeddings.shape:
        raise ValueError("The number of embeddings in the two files does not match.")

    # Compare each pair of embeddings and compute the maximum difference
    for i in range(original_embeddings.shape[0]):
        original_embedding = original_embeddings[i]
        onnx_embedding = onnx_embeddings[i]

        # Compute the element-wise difference
        difference = np.abs(original_embedding - onnx_embedding)
        max_difference = np.max(difference)

        # Print the result
        print(f"Embedding {i + 1}:")
        print(f"Max Difference: {max_difference}")
        print("-" * 50)

compare_java_python_result()
Output:
--------------------------------------------------
Embedding 1:
Max Difference: 7.499999999938112e-07
--------------------------------------------------
Embedding 2:
Max Difference: 5.600000000383076e-07
--------------------------------------------------
Embedding 3:
Max Difference: 1.8700000000565487e-07
--------------------------------------------------
Embedding 4:
Max Difference: 4.500000000406956e-07
--------------------------------------------------
Embedding 5:
Max Difference: 6.000000000172534e-07
--------------------------------------------------
Embedding 6:
Max Difference: 5.699999999775329e-07
--------------------------------------------------
Embedding 7:
Max Difference: 5.700000000885552e-07
--------------------------------------------------
Embedding 8:
Max Difference: 2.1039999999975662e-07
--------------------------------------------------
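Maximum element-wise differences on the order of 1e-7 are within float32 round-off, and for RAG retrieval what ultimately matters is that similarity scores are unaffected. As an extra sanity check, not part of the original comparison, a small sketch like the one below can confirm that the cosine similarity between each Java/Python pair is essentially 1.0; since both sides already L2-normalize their vectors, the dot product is the cosine similarity.

# Extra sanity check (a sketch, not in the original post): cosine similarity between
# each Java embedding and the corresponding Python embedding.
import json
import numpy as np

def load_embeddings(filename):
    with open(filename, 'r') as f:
        return np.array([json.loads(line.strip()) for line in f])

java_embeddings = load_embeddings(r"D:\data\java_bge.txt")
python_embeddings = load_embeddings(r"D:\data\python_bge.txt")

for i, (a, b) in enumerate(zip(java_embeddings, python_embeddings), start=1):
    # Both sets of vectors are already L2-normalized, so the dot product equals cosine similarity.
    print(f"Embedding {i}: cosine similarity = {float(np.dot(a, b)):.8f}")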
Another rewarding day!!! Yay!!!
Tags: Machine Learning