使用PMML部署机器学习模型
PMML简介
预测模型标记语言PMML(Predictive Model Markup Language)是一套与平台和环境无关的模型表示语言,是目前表示机器学习模型的实际标准。
作为一个开放的成熟标准,PMML由数据挖掘组织DMG(Data Mining Group)开发和维护,经过十几年的发展,得到了广泛的应用,有超过30家厂商和开源项目(包括SAS,IBM SPSS,KNIME,RapidMiner等主流厂商)在它们的数据挖掘分析产品中支持并应用PMML,这些厂商应用详情见下表:PMML Powered
PMML标准介绍
PMML是一套基于XML的标准,通过 XML Schema 定义了使用的元素和属性,主要由以下核心部分组成:
数据字典(Data Dictionary),描述输入数据。
数据转换(Transformation Dictionary和Local Transformations),应用在输入数据字段上生成新的派生字段。
模型定义 (Model),每种模型类型有自己的定义。
输出(Output),指定模型输出结果。
PMML预测过程符合数据挖掘分析流程:
pmml-flow.png
PMML优点
平台无关性。PMML可以让模型部署环境脱离开发环境,实现跨平台部署,是PMML区别于其他模型部署方法最大的优点。比如使用Python建立的模型,导出PMML后可以部署在Java生产环境中。
互操作性。这就是标准协议的最大优势,实现了兼容PMML的预测程序可以读取其他应用导出的标准PMML模型。
广泛支持性。已取得30余家厂商和开源项目的支持,通过已有的多个开源库,很多重量级流行的开源数据挖掘模型都可以转换成PMML。
可读性。PMML模型是一个基于XML的文本文件,使用任意的文本编辑器就可以打开并查看文件内容,比二进制序列化文件更安全可靠。
PMML开源类库
模型转换库,生成PMML:
Python模型:
Nyoka,支持Scikit-Learn,LightGBM,XGBoost,Statsmodels和Keras。https://github.com/nyoka-pmml/nyoka
JPMML系列,比如JPMML-SkLearn、JPMML-XGBoost、JPMML-LightGBM等,提供命令行程序导出模型到PMML。https://github.com/jpmml
R模型:
R pmml包:https://cran.r-project.org/web/packages/pmml/index.html
r2pmml:https://github.com/jpmml/r2pmml
JPMML-R:提供命令行程序导出R模型到PMML,https://github.com/jpmml/jpmml-r
Spark:
Spark mllib,但是只是模型本身,不支持Pipelines,不推荐使用。
JPMML-SparkML,支持Spark ML pipleines。https://github.com/jpmml/jpmml-sparkml
模型评估库,读取PMML:
Java:
JPMML-Evaluator,纯Java的PMML预测库,开源协议是AGPL V3。https://github.com/jpmml/jpmml-evaluator
PMML4S,使用Scala开发,同时提供Scala和Java API,接口简单好用,开源协议是常用的宽松协议Apache 2。https://github.com/autodeployai/pmml4s
Python:
PyPMML,PMML的Python预测库,PyPMML是PMML4S包装的Python接口。https://github.com/autodeployai/pypmml
Spark:
JPMML-Evaluator-Spark,https://github.com/jpmml/jpmml-evaluator-spark
PMML4S-Spark,https://github.com/autodeployai/pmml4s-spark
PySpark:
PyPMML-Spark,PySpark中预测PMML模型。https://github.com/autodeployai/pypmml-spark
REST API:
AI-Serving,同时为PMML模型提供REST和gRPC API,开源协议Apache 2。https://github.com/autodeployai/ai-serving
Openscoring,提供REST API,开源协议AGPL V3。https://github.com/openscoring/openscoring
sparkml训练完模型后保存模型为PMML文件:
model.toPMML(spark.sparkContext, "G:\pmml\spark\lr\xml")
java 使用pmml4s加载pmml文件示例:
pmml文件:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML version="4.2" xmlns="http://www.dmg.org/PMML-4_2">
<Header description="logistic regression">
<Application name="Apache Spark MLlib" version="2.4.0"/>
<Timestamp>2020-11-25T19:52:36</Timestamp>
</Header>
<DataDictionary numberOfFields="5">
<DataField name="field_0" optype="continuous" dataType="double"/>
<DataField name="field_1" optype="continuous" dataType="double"/>
<DataField name="field_2" optype="continuous" dataType="double"/>
<DataField name="field_3" optype="continuous" dataType="double"/>
<DataField name="target" optype="categorical" dataType="string"/>
</DataDictionary>
<RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit">
<MiningSchema>
<MiningField name="field_0" usageType="active"/>
<MiningField name="field_1" usageType="active"/>
<MiningField name="field_2" usageType="active"/>
<MiningField name="field_3" usageType="active"/>
<MiningField name="target" usageType="target"/>
</MiningSchema>
<RegressionTable intercept="0.0" targetCategory="1">
<NumericPredictor name="field_0" coefficient="-5.95759188680503"/>
<NumericPredictor name="field_1" coefficient="-1.6974588567364868"/>
<NumericPredictor name="field_2" coefficient="-5.660350922982105"/>
<NumericPredictor name="field_3" coefficient="8.680992926976252"/>
</RegressionTable>
<RegressionTable intercept="-0.0" targetCategory="0"/>
</RegressionModel>
</PMML>
pom文件添加依赖:
<dependency>
<groupId>org.pmml4s</groupId>
<artifactId>pmml4s_2.12</artifactId>
<version>0.9.7</version>
</dependency>
java代码:
import org.pmml4s.model.Model;
import java.util.HashMap;
import java.util.Map;
public class PMML4SDemo {
public static void main(String[] args) {
Model model = Model.fromFile("G:\\pmml\\spark\\lr\\xml\\lr.xml");
Map<String, Object> result = model.predict(new HashMap<String, Object>() {{
put("field_0", 2);
put("field_1", 4);
put("field_2", 1);
put("field_3", 5);
}});
System.out.println(result);
System.out.println(result.get("predicted_target"));
}
}