Java通过共享内存调用机器学习python代码示例
背景:弱鸡的java不支持丰富的机器学习算法。
需求:python实现了一个bert分类,希望给java代码调用。因此,使用共享内存的方式实现跨进程调用。
在进程间通信(IPC)中,性能是一个重要的考虑因素。以下是几种常见的IPC方式及其性能比较:
1. Socket通信:
- 优点: 跨网络和本地通信都可以使用,灵活性高。
- 缺点: 相对较慢,因为涉及网络协议栈的开销。
- 适用场景: 分布式系统、跨主机通信。
2. 管道(Pipes):
- 优点: 简单、易用,适合父子进程间通信。
- 缺点: 只能用于单向通信,且仅限于同一主机。
- 适用场景: 父子进程间的简单数据传输。
3. 共享内存(Shared Memory):
- 优点: 速度最快,因为数据直接在内存中共享。
- 缺点: 需要同步机制(如信号量)来避免竞争条件,编程复杂度较高。
- 适用场景: 高性能需求的进程间通信。
4. 消息队列(Message Queues):
- 优点: 支持消息的有序传递和优先级。
- 缺点: 相对较慢,适合中小规模数据传输。
- 适用场景: 需要消息排队和优先级的场景。
5. 信号量(Semaphores):
- 优点: 用于进程间的同步和互斥。
- 缺点: 只适用于同步,不适合大数据传输。
- 适用场景: 进程间的同步和资源管理。
性能比较
- 共享内存通常是最快的,因为它避免了数据在内核和用户空间之间的拷贝,直接在内存中共享数据。
- 管道和消息队列的性能次之,因为它们需要在内核和用户空间之间进行数据拷贝。
- Socket通信的性能最慢,尤其是跨网络通信时,因为涉及网络协议栈的开销。
选择建议
- 如果你需要在同一主机上的进程间进行高性能通信,共享内存是最佳选择。
- 如果你需要简单的父子进程间通信,管道是一个不错的选择。
- 如果你需要跨主机通信,Socket是唯一的选择。
Java代码:
import java.io.RandomAccessFile; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; public class Main { private final static String FILE_PATH = "D:\\shared_memory.bin3"; private final static int FILE_SIZE = 1024*1024*1024 + 1; // 1MB for text, 1 byte for lock public static void main(String[] args) { try (RandomAccessFile memoryFile = new RandomAccessFile(FILE_PATH, "rw"); FileChannel fileChannel = memoryFile.getChannel()) { MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0, FILE_SIZE); String textToClassify = "什么是sql注入?"; // Wait until the buffer is available for writing while (buffer.get(0) != 0) { System.out.println("Waiting for buffer to be available for writing..."); Thread.sleep(100); } // Write the text to the buffer buffer.put(0, (byte) 1); // Mark as written byte[] textBytes = textToClassify.getBytes(StandardCharsets.UTF_8); buffer.position(1); buffer.put(textBytes); // Pad the remaining part with 0x00 (null character in C, '\0' character) buffer.position(1 + textBytes.length); buffer.put(new byte[1024 - textBytes.length]); System.out.println("Written to memory: " + textToClassify); // Polling for result byte[] data = new byte[1024]; String classificationResult = ""; while (true) { if (buffer.get(0) == 0) { // Check if buffer is available for writing buffer.position(1); buffer.get(data); classificationResult = new String(data, StandardCharsets.UTF_8).trim(); System.out.println("Classification result received: " + classificationResult); break; } Thread.sleep(1); // Sleep for a short time before polling again } } catch (Exception e) { e.printStackTrace(); } } }
python代码:
# server.py import json import mmap import time import os # from bert_classify import predict # from sec_tool_rag import search_sectool_knowledge_base def classify_text(text): return 100 # predicted_label, prediction = predict(text) # return predicted_label FILE_PATH = 'D:\\shared_memory.bin3' FILE_SIZE = 1024*1024*1024 + 1 # 1MB bytes for text, 1 byte for lock if not os.path.exists(FILE_PATH): with open(FILE_PATH, 'w+b') as f: f.write(b'\x00' * FILE_SIZE) with open(FILE_PATH, 'r+b') as f: mm = mmap.mmap(f.fileno(), FILE_SIZE) while True: mm.seek(0) lock = mm.read_byte() if lock == 1: # Data written and ready to be read mm.seek(1) message = mm.read(1024).decode('utf-8').rstrip('\x00').strip() print(f"Received message: {message}") # tools = search_sectool_knowledge_base(message, topk=3) # print(f"Sec tools for '{message}': {tools}") label = classify_text(message) print(f"Classified label: {label}") # str2write = json.dumps(tools) str2write = str(label) # Write the classification result mm.seek(1) mm.write(str2write.encode().ljust(1024, b'\x00')) mm.seek(0) mm.write_byte(0) # Mark as available for writing print("Result written and memory marked as available for writing") time.sleep(0.001)
启动python程序,再启动java代码:
输出
Received message: 什么是sql注入? Classified label: 100 Result written and memory marked as available for writing ------------------------------------------ Written to memory: 什么是sql注入? Classification result received: 100