代码改变世界

JPMML解析PMML模型并导入数据进行分析生成结果

2018-05-12 15:53  halberts  阅读(13276)  评论(2)  编辑  收藏  举报
JPMML解析Random Forest模型并使用其预测分析

导入Jar包

在 Maven 的 pom.xml 文件中添加 JPMML 的依赖:

<dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator</artifactId>
        <version>1.3.7</version>
</dependency>

具体实现代码

主程序类:PmmlCalc(读取参数文件,调用模型,并将预测结果写入 result.txt)

import java.io.*;
import java.nio.charset.Charset;
import java.util.*;

import com.google.common.io.Files;
import org.dmg.pmml.FieldName;

/**
 * Command-line driver that scores a CSV data set against a PMML model.
 *
 * <p>Usage: {@code PmmlCalc <pmmlPath> <csvPath>}. Every data row of the CSV is
 * passed to {@link PmmlInvoker#invoke}, and the string form of each returned
 * value is appended to {@code result.txt} in the working directory.
 *
 * @author biantech
 */
public class PmmlCalc {
    /** Charset name used for reading the parameter file and writing the result file. */
    final static String utf8="utf-8";

    /**
     * Loads the model, reads the parameter file and appends the predictions to
     * {@code result.txt}.
     *
     * @param args {@code args[0]} is the PMML model path, {@code args[1]} the CSV
     *             parameter file path
     * @throws IOException if the parameter file cannot be read or the result file
     *                     cannot be written
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 2) {
            System.out.println("参数个数不匹配");
            // Without both paths there is nothing to do; stop instead of failing on args[0].
            return;
        }
        String pmmlPath = args[0];
        String modelArgsFilePath = args[1];
        PmmlInvoker invoker = new PmmlInvoker(pmmlPath);
        List<Map<FieldName, String>> paramList = readInParams(modelArgsFilePath);

        File file = new File("result.txt");
        // Open the result file once in append mode rather than reopening it for every value.
        try (Writer out = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(file, true), Charset.forName(utf8)))) {
            int lineNum = 0;  // 1-based index of the data row being processed
            for (Map<FieldName, String> param : paramList) {
                lineNum++;
                // The row marker is written without a trailing newline, so the first
                // result value follows it on the same line of the output file.
                out.write("======当前行: " + lineNum + "=======");
                Map<FieldName, ?> result = invoker.invoke(param);
                for (Object value : result.values()) {
                    // String.valueOf tolerates null result values.
                    out.write(String.valueOf(value));
                    out.write("\n");
                }
            }
        }
        System.out.println("resultFile="+file.getAbsolutePath());
    }

    /**
     * Reads a comma-separated parameter file whose first line is the header.
     *
     * <p>The path is first resolved as a classpath resource and, if no such
     * resource exists, as a file system path. Each data line becomes one map
     * from header name to the cell value at the same position. Lines with fewer
     * cells than the header simply produce a smaller map.
     *
     * @param filePath classpath resource name or file system path of the CSV file
     * @return one map per data line, in file order; empty if the file has no header line
     * @throws IOException if the file cannot be opened or read, or if a data line
     *                     has more cells than the header
     */
    public static List<Map<FieldName,String>> readInParams(String filePath) throws IOException{
        InputStream is = PmmlCalc.class.getClassLoader().getResourceAsStream(filePath);
        if (is == null) {
            is = new FileInputStream(filePath);
        }
        List<Map<FieldName,String>> list = new ArrayList<>();
        try (InputStream in = is;
             BufferedReader br = new BufferedReader(new InputStreamReader(in, Charset.forName(utf8)))) {
            String headerLine = br.readLine();
            if (headerLine == null) {
                // Completely empty file: no header, hence no rows.
                return list;
            }
            String[] nameArr = headerLine.split(",");  // column names from the header
            int fileLine = 1;  // 1-based physical line number, used in error messages
            String paramLine;
            while ((paramLine = br.readLine()) != null) {
                fileLine++;
                String[] paramLineArr = paramLine.split(",");
                if (paramLineArr.length > nameArr.length) {
                    throw new IOException("Line " + fileLine + " of " + filePath + " has "
                            + paramLineArr.length + " values but the header declares only "
                            + nameArr.length + " columns");
                }
                Map<FieldName,String> map = new HashMap<>();
                for (int i = 0; i < paramLineArr.length; i++) {
                    // Pair each cell with the header name at the same position.
                    map.put(new FieldName(nameArr[i]), paramLineArr[i]);
                }
                list.add(map);
            }
        }
        return list;
    }
}

模型加载与调用类:PmmlInvoker

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;
/**
 * Loads a PMML document, builds a model evaluator from it and exposes a single
 * {@link #invoke} call that evaluates the model for one input record.
 *
 * @author biantech
 */
public class PmmlInvoker {
    // Raw type kept on purpose: invoke() relies on it to return Map<FieldName, String>.
    private final ModelEvaluator modelEvaluator;

    /**
     * Loads the model from a classpath resource or, if no such resource exists,
     * from the file system, then verifies it.
     *
     * @param pmmlFileName classpath resource name or file system path of the PMML file
     * @throws IllegalArgumentException if {@code pmmlFileName} is {@code null}
     * @throws IllegalStateException    if the file cannot be opened or parsed
     */
    public PmmlInvoker(String pmmlFileName) {
        if (pmmlFileName == null) {
            throw new IllegalArgumentException("pmmlFileName must not be null");
        }
        ModelEvaluator evaluator;
        try (InputStream is = openModel(pmmlFileName)) {
            evaluator = buildEvaluator(is);
        } catch (IOException | SAXException | JAXBException e) {
            // Fail here with the real cause instead of a later NullPointerException.
            throw new IllegalStateException("Failed to load PMML model from " + pmmlFileName, e);
        }
        evaluator.verify();
        this.modelEvaluator = evaluator;
        System.out.println("模型读取成功");
    }

    /**
     * Loads the model from an already opened stream, then verifies it.
     * The stream is closed by this constructor.
     *
     * @param is stream positioned at the start of a PMML document
     * @throws IllegalArgumentException if {@code is} is {@code null}
     * @throws IllegalStateException    if the stream cannot be read or parsed
     */
    public PmmlInvoker(InputStream is) {
        if (is == null) {
            throw new IllegalArgumentException("is must not be null");
        }
        ModelEvaluator evaluator;
        try (InputStream in = is) {
            evaluator = buildEvaluator(in);
        } catch (IOException | SAXException | JAXBException e) {
            throw new IllegalStateException("Failed to load PMML model from stream", e);
        }
        evaluator.verify();
        this.modelEvaluator = evaluator;
    }

    /**
     * Evaluates the model for one record.
     *
     * <p>NOTE(review): the declared value type is {@code String}, but callers in
     * this project print values via {@code toString()}, which suggests they may be
     * other objects; confirm before reading values as {@code String}.
     *
     * @param paramsMap input field values keyed by field name
     * @return whatever the underlying evaluator returns for this record
     */
    public Map<FieldName, String> invoke(Map<FieldName, String> paramsMap) {
        return this.modelEvaluator.evaluate(paramsMap);
    }

    /** Opens the model as a classpath resource, falling back to a plain file. */
    private static InputStream openModel(String pmmlFileName) throws IOException {
        InputStream is = PmmlInvoker.class.getClassLoader().getResourceAsStream(pmmlFileName);
        if (is == null) {
            is = new FileInputStream(pmmlFileName);
        }
        return is;
    }

    /** Parses the PMML document from the stream and creates an evaluator for it. */
    private static ModelEvaluator buildEvaluator(InputStream is) throws SAXException, JAXBException {
        PMML pmml = PMMLUtil.unmarshal(is);
        return ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
    }
}

如何运行

  1. 执行 mvn package 命令,生成 jpmml-parser-1-jar-with-dependencies.jar
  2. 将 PMML 文件、数据集文件和 jar 放在同一个目录下(如 demo-model.pmml、demo-data.csv)
  3. 使用命令行运行

    java -jar jpmml-parser-1-jar-with-dependencies.jar demo-model.pmml demo-data.csv

     

  4. 运行结束后会生成一个result.txt,里面存储的是对数据的预测分析结果

 

======当前行: 1=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]}
setosa
1.0
0.0
0.0
======当前行: 2=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]}
setosa
1.0
0.0
0.0
======当前行: 3=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]}
setosa
1.0
0.0
0.0
======当前行: 4=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]}
setosa
1.0
0.0
0.0
======当前行: 5=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]}
setosa
1.0
0.0
0.0
======当前行: 6=======ProbabilityDistribution{result=setosa, probability_entries=[setosa=1.0]}
setosa
1.0
0.0
0.0

具体源代码请看如下地址

https://github.com/biantech/jpmml-parser