Py4j的简单使用

一、

  Py4J 官网:https://www.py4j.org/ ,详细介绍可以到官网了解。

二、

  本文介绍如何通过 Java 调用 Python 函数,并把 Python 函数的返回值传回 Java。

三、

  演示:

  java代码

import py4j.GatewayServer;

import java.net.InetAddress;

public class ExampleClientApplication {

    /** Address the Java gateway binds to unless PY4J_GATEWAY_ADDRESS overrides it. */
    private static final String DEFAULT_GATEWAY_ADDRESS = "10.252.218.147";

    /**
     * Starts a py4j GatewayServer three times in a row. On each run it calls
     * {@code sayHello2} on the Python entry point (an object implementing {@link IHello})
     * and prints the result, then shuts the server down and waits one second.
     */
    public static void main(String[] args) throws Exception {
        // turnLoggingOff() is static, so a single call before the loop is enough.
        GatewayServer.turnLoggingOff();
        InetAddress javaAddress = InetAddress.getByName(
                System.getenv().getOrDefault("PY4J_GATEWAY_ADDRESS", DEFAULT_GATEWAY_ADDRESS));

        for (int i = 0; i < 3; i++) {
            GatewayServer server = new GatewayServer.GatewayServerBuilder()
                    .javaAddress(javaAddress)
                    .build();
            server.start();
            try {
                // py4j hands back a proxy; calls on it are forwarded to the Python entry point.
                IHello hello = (IHello) server.getPythonServerEntryPoint(new Class[] { IHello.class });
                callPython(hello);
            } finally {
                // Always shut the server down, even if obtaining the proxy failed,
                // so it is not left running when the next iteration starts a new one.
                server.shutdown();
            }
            Thread.sleep(1000L);
        }
    }

    /** Invokes the Python implementation and prints its result; failures are printed, not rethrown. */
    private static void callPython(IHello hello) {
        try {
            System.out.println(hello.sayHello2("a", "b", "c"));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
/**
 * Contract reached from Java through a py4j proxy. There is no Java implementation;
 * the Python class {@code SimpleHello} provides the behaviour.
 */
public interface IHello {

    /** Invokes the Python {@code sayHello} without arguments. */
    String sayHello();

    /**
     * Invokes the Python {@code sayHello}. In the accompanying Python code this reads the
     * Excel file at {@code path} and returns decision-tree metrics as a JSON string.
     */
    String sayHello(int i, String s, String path);

    /** Invokes the Python {@code sayHello2}, which joins the three strings with '_'. */
    String sayHello2(String s1, String s2, String s3);
}

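  其中:IHello 是一个 Java 接口,Java 侧没有编写它的实现类,对它的方法调用实际执行的是 Python 侧的同名函数。示例中实际调用的是 sayHello2 方法(3 个字符串参数),调用 sayHello(int, String, String) 的代码已被注释掉。注意:Python 代码中 implements 声明的是接口的全限定名 org.example.java_call_python.IHello,因此 Java 侧的 IHello 需要放在 org.example.java_call_python 包中,并且与 ExampleClientApplication 分别放在各自的源文件里。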

  python代码

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
import pandas as pd
import json
from sklearn import tree
import graphviz
import pydotplus
import matplotlib

# Select matplotlib's non-interactive Agg backend so no display/GUI is needed
# (this script runs as a py4j callback server for Java).
matplotlib.use('agg')


class SimpleHello(object):
    """Python implementation of the Java ``IHello`` interface, exposed through py4j.

    Java obtains this object as the Python server entry point and calls its methods
    through an ``IHello`` proxy. The nested ``Java`` class tells py4j which Java
    interface(s) this object implements.
    """

    def sayHello2(self, s1, s2, s3):
        """Return the three strings joined with underscores, e.g. ``"a_b_c"``."""
        return "_".join((s1, s2, s3))

    def sayHello(self, int_value=None, string_value=None, path=None):
        """Train a decision tree on an Excel sheet and return its metrics as JSON.

        The sheet must contain a ``flag`` column (the label) and ``boss_id`` / ``ds``
        columns (dropped). Every remaining column is used as a feature. As side
        effects, the fitted tree is written to ``./model.dot`` and rendered, with
        feature names, to ``./model_5.svg``.

        :param int_value: unused; kept so ``IHello.sayHello(int, String, String)`` maps here.
        :param string_value: unused; kept for the same reason.
        :param path: path of the Excel file to read.
        :returns: JSON object string with keys ``Importance``, ``matrix``, ``Accuracy``,
            ``Recall``, ``Precision`` and ``F``. The last three are macro-averaged.
            All metrics are computed on the training rows themselves.
        :raises ValueError: if *path* is ``None`` (e.g. when called with no arguments).
        """
        # Fail fast with a clear message instead of an obscure pandas error.
        if path is None:
            raise ValueError("sayHello requires the path of an Excel file")

        features, labels, feature_names = self._load_dataset(path)

        dt = DecisionTreeClassifier(max_depth=4, min_samples_leaf=1, criterion='entropy')
        dt.fit(features, labels)

        metrics = self._evaluate(dt, features, labels)
        self._export_tree(dt, feature_names)
        return json.dumps(metrics)

    @staticmethod
    def _load_dataset(path):
        """Read *path* and return (feature matrix, label vector, feature-name array)."""
        df = pd.read_excel(path)
        feature_df = df.drop(columns=['flag', 'boss_id', 'ds'])
        return np.array(feature_df), np.array(df['flag']), np.array(feature_df.columns)

    @staticmethod
    def _evaluate(model, features, labels):
        """Print the fitted *model*'s metrics and return them in JSON key order."""
        metrics = {}

        importance = model.feature_importances_
        print("Feature Importance:", importance)
        metrics['Importance'] = importance.tolist()

        # NOTE: predictions are made on the same rows the model was trained on,
        # so these are training-set metrics, not held-out test metrics.
        predicted = model.predict(features)

        cm = confusion_matrix(labels, predicted)
        print("Confusion Matrix:\n", cm)
        metrics["matrix"] = cm.tolist()

        accuracy = accuracy_score(labels, predicted)
        print("Accuracy:", accuracy)
        metrics["Accuracy"] = accuracy

        # average="macro": unweighted mean of the per-class scores.
        for printed_name, key, scorer in (("Recall", "Recall", recall_score),
                                          ("Precision", "Precision", precision_score),
                                          ("F1 Score", "F", f1_score)):
            value = scorer(labels, predicted, average="macro")
            print(printed_name + ":", value)
            metrics[key] = value
        return metrics

    @staticmethod
    def _export_tree(model, feature_names):
        """Write *model* to ./model.dot and a labelled rendering to ./model_5.svg."""
        # NOTE(review): only matters for matplotlib drawing, and nothing is plotted
        # here any more; kept because it mutates global matplotlib state.
        plt.rcParams['font.sans-serif'] = ['SimHei']

        # Plain Graphviz dump without feature names. With out_file set,
        # export_graphviz writes into the file and returns None.
        with open("./model.dot", 'w', encoding='utf-8') as dot_file:
            tree.export_graphviz(model, out_file=dot_file)

        dot_data = tree.export_graphviz(model, out_file=None,
                                        feature_names=feature_names,
                                        filled=True, rounded=True,
                                        special_characters=True)
        # Replace the default font, presumably so Chinese column names render.
        dot_data = dot_data.replace("helvetica", "MicrosoftYaHei")
        graph = pydotplus.graph_from_dot_data(dot_data)
        graph.write_svg("./model_5.svg")

    class Java:
        # Fully-qualified name of the Java interface this object implements for py4j.
        implements = ["org.example.java_call_python.IHello"]


# Start this Python script first, then run the Java ExampleClientApplication.
# Passing callback_server_parameters makes JavaGateway start a callback server,
# which exposes simple_hello to Java as the Python server entry point.
import os

from py4j.java_gateway import JavaGateway, CallbackServerParameters, GatewayParameters

simple_hello = SimpleHello()
# Address of the Java GatewayServer; it must match the address the Java side binds to.
# Can be overridden with the PY4J_GATEWAY_ADDRESS environment variable.
gateway_param = GatewayParameters(
    address=os.environ.get("PY4J_GATEWAY_ADDRESS", "10.252.218.147"))
gateway = JavaGateway(
    gateway_parameters=gateway_param,
    callback_server_parameters=CallbackServerParameters(),
    python_server_entry_point=simple_hello)

  使用方式:先运行 Python 程序,再运行 Java 程序,二者通过 GatewayServer 进行通信。

  效果展示:

 

   这样就实现了在 Java 中调用 Python 函数并取得其返回值。博主的实际场景是在 Java 中调用 Python 实现的决策树算法,大家可以根据自己的需要封装函数内部逻辑。

 

posted @ 2023-05-25 10:19  Coding_Now  阅读(1077)  评论(0)  编辑  收藏  举报