解析 LightGBM 的 txt 模型文件
根据近期的 GitHub 方案(jpmml-lightgbm + jpmml-evaluator),将 txt 格式的 LightGBM 模型文件转换为 PMML,并加载为可执行预测的 evaluator
添加依赖
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-lightgbm</artifactId>
<version>1.5.4</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.6.6</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-model</artifactId>
<version>1.6.6</version>
</dependency>
工具类
import lombok.extern.slf4j.Slf4j;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.lightgbm.GBDT;
import org.jpmml.lightgbm.HasLightGBMOptions;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.model.metro.MetroJAXBUtil;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 加载、初始化 PMML模型文件 :
* 依赖 pmml-lightgbm-1.5.0(AGPL-3.0 License)
* <p>
* 解析PMML文件 @link https://github.com/jpmml/jpmml-lightgbm
* 生成evaluator @link https://github.com/jpmml/jpmml-evaluator
*/
@Slf4j
public class LightgbmTxtInitializer {
// description = "Custom objective function"
private static String objectiveFunction = null;
// description = "Transform LightGBM-style trees to PMML-style trees",
private static boolean compact = true;
// description = "Treat Not-a-Number (NaN) values as missing values",
private static boolean nanAsMissing = true;
// description = "Limit the number of trees. Defaults to all trees"
private static Integer numIteration = null;
// description = "Target name. Defaults to \"_target\""
private static String targetName = null;
// description = "Target categories. Defaults to 0-based index [0, 1, .., num_class - 1]"
private static List<String> targetCategories = null;
public static void main(String[] output) throws Exception {
Resource resource = new ClassPathResource("lightgbm_model.txt");
InputStream pmmlFileInputStream = resource.getInputStream();
// 生成模型执行器
ModelEvaluator evaluator = initEvaluator(pmmlFileInputStream);
// 打印特征参数
List<InputField> inputFields = evaluator.getInputFields();
log.info("ModelEvaluator featureNames:" + inputFields);
// 调试执行预测
Map<String, Number> waitPreSample = new HashMap<>(8);
waitPreSample.put("0", 0.1);
waitPreSample.put("1", 0.2);
waitPreSample.put("2", 0.3);
String predictedValue = getPredictedValue(waitPreSample, evaluator);
pmmlFileInputStream.close();
}
public static ModelEvaluator initEvaluator(InputStream pmmlFileInputStream) throws Exception {
GBDT gbdt;
long begin = System.currentTimeMillis();
gbdt = LightGBMUtil.loadGBDT(pmmlFileInputStream);
log.info("Loaded GBDT in {} ms.", (System.currentTimeMillis() - begin));
if (objectiveFunction != null) {
log.info("Setting custom objective function");
gbdt.setObjectiveFunction(LightGBMUtil.parseObjectiveFunction(objectiveFunction));
}
Map<String, Object> options = new LinkedHashMap<>();
options.put(HasLightGBMOptions.OPTION_COMPACT, compact);
options.put(HasLightGBMOptions.OPTION_NAN_AS_MISSING, nanAsMissing);
options.put(HasLightGBMOptions.OPTION_NUM_ITERATION, numIteration);
// 生成标准PMML
begin = System.currentTimeMillis();
PMML pmml;
pmml = gbdt.encodePMML(options, targetName, targetCategories);
long end = System.currentTimeMillis();
log.info("Converted GBDT to PMML in {} ms.", (System.currentTimeMillis() - begin));
// no need
// 输出PMML格式文件
begin = System.currentTimeMillis();
File outputFile = new File("E://t.pmml");
OutputStream os = new FileOutputStream(outputFile);
MetroJAXBUtil.marshalPMML(pmml, os);
log.info("Marshalled PMML in {} ms.", (System.currentTimeMillis() - begin));
// 生成evaluator
begin = System.currentTimeMillis();
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);
ModelEvaluator<?> evaluator = modelEvaluatorBuilder.build();
evaluator.verify();
log.info("Init evaluator in {} ms.", (System.currentTimeMillis() - begin));
return evaluator;
}
public static String getPredictedValue(Map<String, ?> argumentMap,
ModelEvaluator<?> evaluator) {
// 预测计算
Map<String, ?> evaluateResult = evaluator.evaluate(argumentMap);
log.info("evaluateResult:" + evaluateResult);
// 提取预测结果
String predictedValue = null;
TargetField targetFieldName = evaluator.getTargetField();
Object targetFieldValue = evaluateResult.get(targetFieldName.getFieldName());
// 输出预测结果
if (targetFieldValue instanceof ProbabilityDistribution) {
predictedValue = ((ProbabilityDistribution<?>) targetFieldValue).getPrediction().toString();
log.info("Predicted value(ProbabilityDistribution) : " + predictedValue);
} else if (targetFieldValue instanceof FieldValue) {
FieldValue fieldValue = (FieldValue) targetFieldValue;
predictedValue = fieldValue.asString();
log.info("Predicted value(FieldValue) : " + predictedValue);
} else if (targetFieldValue instanceof List) {
List<String> resultList =
((List<?>) targetFieldValue)
.stream()
.map(e -> ((FieldValue) e).asString())
.collect(Collectors.toList());
predictedValue = String.join(",", resultList);
log.info("Predicted value(List) : " + predictedValue);
} else {
log.error("unknown type for targetFieldValue:" + targetFieldValue);
}
return predictedValue;
}
}