Calling a PMML algorithm model from Java
==Background==
The project needs to call an algorithm model to optimize process parameters.
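==Approach==
The team cannot train algorithm models in-house, so we separate training from use.
Model training is sourced from a third party, and the trained model is exported as a PMML file. The Java application then loads and calls that model file.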
==Code sample==
【Prepare the model】
I got a test model file from a friend, shown in the figure below: 模型case.mht
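【Python code】
Write Python code that trains a decision tree on the iris data set and exports it as a PMML file with sklearn2pmml. Note that sklearn2pmml needs a Java runtime on the system path, because the conversion is performed by a Java library.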
# Import modules
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier as DTC
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

# Load the iris data set
iris = load_iris()
X = iris.data
y = iris.target

# Build the model pipeline
pipeline = PMMLPipeline([
    ("classifier", DTC(criterion='entropy', max_depth=5, min_samples_leaf=1))
])

# Train the model
pipeline.fit(X, y)

# Export the model to the file iris_test.pmml
sklearn2pmml(pipeline, "iris_test.pmml")
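sklearn2pmml decides which PMML version it writes, and newer releases may emit PMML 4.4. That version is what triggers the pitfall described later in this post. Before loading a freshly exported file in Java, you can check the version it declares by reading the namespace of its root element. Below is a minimal sketch that uses only the JDK's StAX API. The class and method names (PmmlVersionCheck, rootNamespace) are hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLStreamConstants;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;

public class PmmlVersionCheck {

    /** Returns the namespace URI of the root element, e.g. "http://www.dmg.org/PMML-4_4". */
    public static String rootNamespace(InputStream in) {
        try {
            XMLStreamReader reader = XMLInputFactory.newInstance().createXMLStreamReader(in);
            try {
                // Advance past the XML declaration, comments, etc. to the first element
                while (reader.hasNext()) {
                    if (reader.next() == XMLStreamConstants.START_ELEMENT) {
                        return reader.getNamespaceURI();
                    }
                }
                return null; // the document contains no element
            } finally {
                reader.close();
            }
        } catch (XMLStreamException e) {
            throw new IllegalArgumentException("Not a readable XML document", e);
        }
    }

    public static void main(String[] args) {
        // Hypothetical sample header standing in for an exported PMML file
        String sample = "<?xml version=\"1.0\"?>"
                + "<PMML xmlns=\"http://www.dmg.org/PMML-4_4\" version=\"4.4\"></PMML>";
        String ns = rootNamespace(new ByteArrayInputStream(sample.getBytes(StandardCharsets.UTF_8)));
        System.out.println(ns); // http://www.dmg.org/PMML-4_4
    }
}
```

In practice you would pass a FileInputStream for iris_test.pmml instead of the inline sample.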
【Write the Java code】
1. pom file (Maven dependencies)
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.1</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.1</version>
</dependency>
2. Java code
package Pmml;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Algorithm model test
 * @author Chunhui.Qu
 */
public class PmmlTest {

    public static void main(String[] args) throws Exception {
        // Model path (a copy of the exported file with its namespace changed to PMML 4.3)
        String path = "D:\\work\\06_WorkFile\\蓝格赛\\机器学习\\iris_test_4_3.pmml";
        // Input feature values; keys must match the input field names in the PMML file
        Map<String, Double> map1 = new HashMap<>();
        map1.put("x1", 5.1);
        map1.put("x2", 3.5);
        map1.put("x3", 0.4);
        map1.put("x4", 0.2);
        // Run the model prediction
        predictLrHeart(map1, path);
    }

    public static void predictLrHeart(Map<String, Double> irismap, String path) throws Exception {
        File file = new File(path);
        try (InputStream is = new FileInputStream(file)) {
            // Parse the PMML document and build an evaluator for the model it contains
            PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
            Evaluator evaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

            // Convert each raw input value into the type declared for that field in the PMML file
            List<InputField> inputFields = evaluator.getInputFields();
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
            for (InputField inputField : inputFields) {
                FieldName inputFieldName = inputField.getName();
                Object rawValue = irismap.get(inputFieldName.getValue());
                FieldValue inputFieldValue = inputField.prepare(rawValue);
                arguments.put(inputFieldName, inputFieldValue);
            }

            // Evaluate the model and print the value of each target field
            Map<FieldName, ?> results = evaluator.evaluate(arguments);
            List<TargetField> targetFields = evaluator.getTargetFields();
            for (TargetField targetField : targetFields) {
                FieldName targetFieldName = targetField.getName();
                Object targetFieldValue = results.get(targetFieldName);
                System.out.println(targetFieldValue);
            }
        }
    }
}
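The keys in map1 have to match the input field names in the PMML file. The Python script fits the pipeline on a plain NumPy array, which has no column names. In that case sklearn2pmml, by default, names the inputs x1, x2, ... and the target y. If a key does not match, irismap.get(...) returns null, and that null is passed to inputField.prepare(...) as a missing value. Below is a small sketch that builds the map from one row of raw values, assuming that default naming. IrisFeatures and toFeatureMap are hypothetical names.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class IrisFeatures {

    /**
     * Builds the input map for predictLrHeart from one row of raw feature values.
     * Assumes sklearn2pmml's default names x1..xN (column 0 becomes x1, column 1 becomes x2, ...).
     */
    public static Map<String, Double> toFeatureMap(double... row) {
        Map<String, Double> features = new LinkedHashMap<>();
        for (int i = 0; i < row.length; i++) {
            features.put("x" + (i + 1), row[i]);
        }
        return features;
    }

    public static void main(String[] args) {
        Map<String, Double> features = toFeatureMap(5.1, 3.5, 0.4, 0.2);
        System.out.println(features); // {x1=5.1, x2=3.5, x3=0.4, x4=0.2}
    }
}
```

The result can be passed straight to predictLrHeart in place of the hand-built map1. A LinkedHashMap keeps the columns in order, which makes the printed map easier to compare with the training data.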
【Run result】
==A pitfall==
Running the Java code may fail with the following exception:
Exception in thread "main" java.lang.IllegalArgumentException: http://www.dmg.org/PMML-4_4
	at org.dmg.pmml.Version.forNamespaceURI(Version.java:62)
	at org.jpmml.model.filters.PMMLFilter.updateSource(PMMLFilter.java:121)
	at org.jpmml.model.filters.PMMLFilter.startPrefixMapping(PMMLFilter.java:43)
	at com.sun.org.apache.xerces.internal.parsers.AbstractSAXParser.startNamespaceMapping(AbstractSAXParser.java:2169)
	at com.sun.org.apache.xerces.internal.parsers.AbstractSAXParser.startElement(AbstractSAXParser.java:470)
	at com.sun.org.apache.xerces.internal.impl.XMLNSDocumentScannerImpl.scanStartElement(XMLNSDocumentScannerImpl.java:375)
	at com.sun.org.apache.xerces.internal.impl.XMLNSDocumentScannerImpl$NSContentDriver.scanRootElementHook(XMLNSDocumentScannerImpl.java:614)
	at com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl$FragmentContentDriver.next(XMLDocumentFragmentScannerImpl.java:3134)
	at com.sun.org.apache.xerces.internal.impl.XMLDocumentScannerImpl$PrologDriver.next(XMLDocumentScannerImpl.java:867)
	at com.sun.org.apache.xerces.internal.impl.XMLDocumentScannerImpl.next(XMLDocumentScannerImpl.java:605)
	at com.sun.org.apache.xerces.internal.impl.XMLNSDocumentScannerImpl.next(XMLNSDocumentScannerImpl.java:113)
	at com.sun.org.apache.xerces.internal.impl.XMLDocumentFragmentScannerImpl.scanDocument(XMLDocumentFragmentScannerImpl.java:507)
	at com.sun.org.apache.xerces.internal.parsers.XML11Configuration.parse(XML11Configuration.java:867)
	at com.sun.org.apache.xerces.internal.parsers.XML11Configuration.parse(XML11Configuration.java:796)
	at com.sun.org.apache.xerces.internal.parsers.XMLParser.parse(XMLParser.java:142)
	at com.sun.org.apache.xerces.internal.parsers.AbstractSAXParser.parse(AbstractSAXParser.java:1216)
	at org.xml.sax.helpers.XMLFilterImpl.parse(XMLFilterImpl.java:357)
	at com.sun.xml.internal.bind.v2.runtime.unmarshaller.UnmarshallerImpl.unmarshal0(UnmarshallerImpl.java:243)
	at com.sun.xml.internal.bind.v2.runtime.unmarshaller.UnmarshallerImpl.unmarshal(UnmarshallerImpl.java:214)
	at javax.xml.bind.helpers.AbstractUnmarshallerImpl.unmarshal(AbstractUnmarshallerImpl.java:140)
	at javax.xml.bind.helpers.AbstractUnmarshallerImpl.unmarshal(AbstractUnmarshallerImpl.java:123)
	at org.jpmml.model.JAXBUtil.unmarshal(JAXBUtil.java:82)
	at org.jpmml.model.JAXBUtil.unmarshalPMML(JAXBUtil.java:68)
	at org.jpmml.model.PMMLUtil.unmarshal(PMMLUtil.java:35)
	at Pmml.PmmlTest.predictLrHeart(PmmlTest.java:38)
	at Pmml.PmmlTest.main(PmmlTest.java:32)
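Cause: pmml-evaluator 1.4.1 only recognizes PMML schema versions up to 4.3. The exported file declares the PMML 4.4 namespace, so Version.forNamespaceURI rejects it while the file is being parsed.
Fix: open the PMML file and change the namespace on the root <PMML> element from 4_4 to 4_3.
Before: http://www.dmg.org/PMML-4_4
After: http://www.dmg.org/PMML-4_3
This workaround assumes the model uses no elements introduced in PMML 4.4. Another option is to upgrade to a JPMML release that supports PMML 4.4 (the 1.5.x line), which avoids editing the file, although its API differs somewhat from 1.4.1.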
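If you prefer not to hand-edit the exported file, the same namespace change can be made in memory just before parsing. The sketch below uses only the JDK and assumes the file is UTF-8 encoded. PmmlNamespaceFix, downgradeNamespace and openAsPmml43 are hypothetical names. Like the manual edit, it rewrites only the namespace URI and leaves the version attribute untouched. It also only makes sense if the model uses nothing introduced in PMML 4.4.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

public class PmmlNamespaceFix {

    static final String NS_44 = "http://www.dmg.org/PMML-4_4";
    static final String NS_43 = "http://www.dmg.org/PMML-4_3";

    /** Replaces every occurrence of the PMML 4.4 namespace URI with the 4.3 one. */
    public static String downgradeNamespace(String pmmlXml) {
        return pmmlXml.replace(NS_44, NS_43);
    }

    /**
     * Reads a PMML file (assumed UTF-8), applies the namespace change in memory and returns
     * a stream that can be handed to PMMLUtil.unmarshal(...). The file on disk is not modified.
     */
    public static InputStream openAsPmml43(Path file) throws IOException {
        String xml = new String(Files.readAllBytes(file), StandardCharsets.UTF_8);
        return new ByteArrayInputStream(downgradeNamespace(xml).getBytes(StandardCharsets.UTF_8));
    }

    public static void main(String[] args) {
        String header = "<PMML xmlns=\"http://www.dmg.org/PMML-4_4\" version=\"4.4\">";
        System.out.println(downgradeNamespace(header));
        // <PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.4">
    }
}
```

In predictLrHeart, new FileInputStream(file) could then be replaced with PmmlNamespaceFix.openAsPmml43(file.toPath()), and the rest of the method stays the same.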