Java通过共享内存调用机器学习python代码示例

背景:弱鸡的java不支持丰富的机器学习算法。

需求:python实现了一个bert分类,希望给java代码调用。因此,使用共享内存的方式实现跨进程调用。

在进程间通信(IPC)中,性能是一个重要的考虑因素。以下是几种常见的IPC方式及其性能比较:

1. Socket通信:

  • 优点: 跨网络和本地通信都可以使用,灵活性高。
  • 缺点: 相对较慢,因为涉及网络协议栈的开销。
  • 适用场景: 分布式系统、跨主机通信。

2. 管道(Pipes):

  • 优点: 简单、易用,适合父子进程间通信。
  • 缺点: 只能用于单向通信,且仅限于同一主机。
  • 适用场景: 父子进程间的简单数据传输。

3. 共享内存(Shared Memory):

  • 优点: 速度最快,因为数据直接在内存中共享。
  • 缺点: 需要同步机制(如信号量)来避免竞争条件,编程复杂度较高。
  • 适用场景: 高性能需求的进程间通信。

4. 消息队列(Message Queues):

  • 优点: 支持消息的有序传递和优先级。
  • 缺点: 相对较慢,适合中小规模数据传输。
  • 适用场景: 需要消息排队和优先级的场景。

5. 信号量(Semaphores):

  • 优点: 用于进程间的同步和互斥。
  • 缺点: 只适用于同步,不适合大数据传输。
  • 适用场景: 进程间的同步和资源管理。

性能比较

  • 共享内存通常是最快的,因为它避免了数据在内核和用户空间之间的拷贝,直接在内存中共享数据。
  • 管道和消息队列的性能次之,因为它们需要在内核和用户空间之间进行数据拷贝。
  • Socket通信的性能最慢,尤其是跨网络通信时,因为涉及网络协议栈的开销。

选择建议

  • 如果你需要在同一主机上的进程间进行高性能通信,共享内存是最佳选择。
  • 如果你需要简单的父子进程间通信,管道是一个不错的选择。
  • 如果你需要跨主机通信,Socket是唯一的选择。

 

 

Java代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import java.io.RandomAccessFile;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
 
public class Main {
    private final static String FILE_PATH = "D:\\shared_memory.bin3";
    private final static int FILE_SIZE =  1024*1024*1024 + 1; // 1MB for text, 1 byte for lock
 
    public static void main(String[] args) {
        try (RandomAccessFile memoryFile = new RandomAccessFile(FILE_PATH, "rw");
             FileChannel fileChannel = memoryFile.getChannel()) {
 
            MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0, FILE_SIZE);
 
            String textToClassify = "什么是sql注入?";
 
            // Wait until the buffer is available for writing
            while (buffer.get(0) != 0) {
                System.out.println("Waiting for buffer to be available for writing...");
                Thread.sleep(100);
            }
 
            // Write the text to the buffer
            buffer.put(0, (byte) 1);  // Mark as written
            byte[] textBytes = textToClassify.getBytes(StandardCharsets.UTF_8);
            buffer.position(1);
            buffer.put(textBytes);
            // Pad the remaining part with 0x00 (null character in C, '\0' character)
            buffer.position(1 + textBytes.length);
            buffer.put(new byte[1024 - textBytes.length]);
            System.out.println("Written to memory: " + textToClassify);
 
            // Polling for result
            byte[] data = new byte[1024];
            String classificationResult = "";
            while (true) {
                if (buffer.get(0) == 0) {  // Check if buffer is available for writing
                    buffer.position(1);
                    buffer.get(data);
                    classificationResult = new String(data, StandardCharsets.UTF_8).trim();
                    System.out.println("Classification result received: " + classificationResult);
                    break;
                }
                Thread.sleep(1);  // Sleep for a short time before polling again
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

  

python代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# server.py
import json
import mmap
import time
import os
# from bert_classify import predict
# from sec_tool_rag import search_sectool_knowledge_base
 
def classify_text(text):
    return 100
    # predicted_label, prediction = predict(text)
    # return predicted_label
 
FILE_PATH = 'D:\\shared_memory.bin3'
FILE_SIZE = 1024*1024*1024 + 1  # 1MB bytes for text, 1 byte for lock
 
if not os.path.exists(FILE_PATH):
    with open(FILE_PATH, 'w+b') as f:
        f.write(b'\x00' * FILE_SIZE)
 
with open(FILE_PATH, 'r+b') as f:
    mm = mmap.mmap(f.fileno(), FILE_SIZE)
    while True:
        mm.seek(0)
        lock = mm.read_byte()
 
        if lock == 1# Data written and ready to be read
            mm.seek(1)
            message = mm.read(1024).decode('utf-8').rstrip('\x00').strip()
            print(f"Received message: {message}")
            # tools = search_sectool_knowledge_base(message, topk=3)
            # print(f"Sec tools for '{message}': {tools}")
 
            label = classify_text(message)
            print(f"Classified label: {label}")
            # str2write = json.dumps(tools)
            str2write = str(label)
            # Write the classification result
            mm.seek(1)
            mm.write(str2write.encode().ljust(1024, b'\x00'))
            mm.seek(0)
            mm.write_byte(0# Mark as available for writing
            print("Result written and memory marked as available for writing")
 
        time.sleep(0.001)

  

启动python程序,再启动java代码:

输出

1
2
3
4
5
6
7
8
Received message: 什么是sql注入?
Classified label: 100
Result written and memory marked as available for writing
 
------------------------------------------
 
Written to memory: 什么是sql注入?
Classification result received: 100

  

 

posted @   bonelee  阅读(114)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2022-06-25 自建门罗币矿池——todo,待实践
2022-06-25 事件分析|门罗币挖矿新家族「罗生门」——注意:这个挖矿可以设置CPU的阈值
2022-06-25 xmrig挖矿样本分析 miner
2021-06-25 Darknet暗网数据集——https://www.unb.ca/cic/datasets/darknet2020.html
2021-06-25 HIDS一般具有的功能
2021-06-25 利用CTU-13数据集进行僵尸网络检测
2021-06-25 twitter僵尸网路检测,只能twitter自己做这种算法
点击右上角即可分享
微信分享提示