解析 LightGBM 的 txt 模型文件

参考 GitHub 上的 jpmml-lightgbm 项目，实现加载 LightGBM 导出的 txt 格式模型文件：先将其转换为 PMML，再生成 JPMML 评估器（evaluator）用于预测。

添加 Maven 依赖

<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-lightgbm</artifactId>
    <version>1.5.4</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.6.6</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-model</artifactId>
    <version>1.6.6</version>
</dependency>

工具类

import lombok.extern.slf4j.Slf4j;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.lightgbm.GBDT;
import org.jpmml.lightgbm.HasLightGBMOptions;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.model.metro.MetroJAXBUtil;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;

import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Loads a LightGBM text model file, converts it to PMML and builds a JPMML evaluator for it.
 *
 * <p>Depends on pmml-lightgbm 1.5.x (AGPL-3.0 License).
 *
 * <ul>
 *   <li>LightGBM txt to PMML conversion: https://github.com/jpmml/jpmml-lightgbm</li>
 *   <li>PMML evaluation: https://github.com/jpmml/jpmml-evaluator</li>
 * </ul>
 */
@Slf4j
public class LightgbmTxtInitializer {

    /**
     * File the converted PMML document is dumped to by {@link #initEvaluator(InputStream)}.
     * The dump is diagnostic only; failing to write it does not fail initialization.
     */
    private static final String DEFAULT_PMML_DUMP_PATH = "E://t.pmml";

    // Custom objective function; when null, the objective function loaded with the model is kept.
    private static String objectiveFunction = null;
    // Transform LightGBM-style trees to PMML-style trees.
    private static boolean compact = true;
    // Treat Not-a-Number (NaN) values as missing values.
    private static boolean nanAsMissing = true;
    // Limit the number of trees. Defaults to all trees.
    private static Integer numIteration = null;
    // Target name. Defaults to "_target".
    private static String targetName = null;
    // Target categories. Defaults to 0-based index [0, 1, .., num_class - 1].
    private static List<String> targetCategories = null;


    /**
     * Debug entry point.
     *
     * <p>Loads {@code lightgbm_model.txt} from the classpath, logs the model's input fields and scores
     * one hard-coded sample. NOTE(review): the sample keys "0", "1", "2" must match the model's
     * feature names; confirm against the model file.
     *
     * @param args ignored
     * @throws Exception if the model cannot be loaded, converted or evaluated
     */
    public static void main(String[] args) throws Exception {
        Resource resource = new ClassPathResource("lightgbm_model.txt");
        // try-with-resources: the classpath stream is closed even if initialization fails.
        try (InputStream modelInputStream = resource.getInputStream()) {
            ModelEvaluator<?> evaluator = initEvaluator(modelInputStream);

            List<InputField> inputFields = evaluator.getInputFields();
            log.info("ModelEvaluator featureNames:{}", inputFields);

            Map<String, Number> sample = new HashMap<>(8);
            sample.put("0", 0.1);
            sample.put("1", 0.2);
            sample.put("2", 0.3);
            // The prediction is logged inside getPredictedValue.
            getPredictedValue(sample, evaluator);
        }
    }

    /**
     * Converts a LightGBM text model to PMML, dumps the PMML to {@value #DEFAULT_PMML_DUMP_PATH}
     * (best-effort) and builds a verified evaluator.
     *
     * <p>The stream is read but not closed; closing it remains the caller's responsibility.
     *
     * @param pmmlFileInputStream stream over a LightGBM txt model (not a PMML document)
     * @return a verified evaluator for the converted model
     * @throws Exception if loading, conversion or evaluator construction fails
     */
    @SuppressWarnings("rawtypes") // raw return type kept for backward compatibility with existing callers
    public static ModelEvaluator initEvaluator(InputStream pmmlFileInputStream) throws Exception {
        return initEvaluator(pmmlFileInputStream, new File(DEFAULT_PMML_DUMP_PATH));
    }

    /**
     * Converts a LightGBM text model to PMML, optionally dumps the PMML to a file and builds a
     * verified evaluator.
     *
     * @param modelInputStream stream over a LightGBM txt model; not closed by this method
     * @param pmmlDumpFile file to write the converted PMML to, or {@code null} to skip the dump;
     *     write failures are logged and otherwise ignored
     * @return a verified evaluator for the converted model
     * @throws Exception if loading, conversion or evaluator construction fails
     */
    @SuppressWarnings("rawtypes") // raw return type kept consistent with the single-argument overload
    public static ModelEvaluator initEvaluator(InputStream modelInputStream, File pmmlDumpFile)
            throws Exception {
        GBDT gbdt = loadGbdt(modelInputStream);
        PMML pmml = encodePmml(gbdt);
        if (pmmlDumpFile != null) {
            dumpPmml(pmml, pmmlDumpFile);
        }
        return buildEvaluator(pmml);
    }

    /** Parses the LightGBM txt model and applies the custom objective function, if configured. */
    private static GBDT loadGbdt(InputStream modelInputStream) throws Exception {
        long begin = System.currentTimeMillis();
        GBDT gbdt = LightGBMUtil.loadGBDT(modelInputStream);
        log.info("Loaded GBDT in {} ms.", (System.currentTimeMillis() - begin));

        if (objectiveFunction != null) {
            log.info("Setting custom objective function");
            gbdt.setObjectiveFunction(LightGBMUtil.parseObjectiveFunction(objectiveFunction));
        }
        return gbdt;
    }

    /** Encodes the GBDT as a standard PMML document using the static conversion options. */
    private static PMML encodePmml(GBDT gbdt) {
        Map<String, Object> options = new LinkedHashMap<>();
        options.put(HasLightGBMOptions.OPTION_COMPACT, compact);
        options.put(HasLightGBMOptions.OPTION_NAN_AS_MISSING, nanAsMissing);
        options.put(HasLightGBMOptions.OPTION_NUM_ITERATION, numIteration);

        long begin = System.currentTimeMillis();
        PMML pmml = gbdt.encodePMML(options, targetName, targetCategories);
        log.info("Converted GBDT to PMML in {} ms.", (System.currentTimeMillis() - begin));
        return pmml;
    }

    /** Writes the PMML document to {@code outputFile}; diagnostic only, so failures are logged, not thrown. */
    private static void dumpPmml(PMML pmml, File outputFile) {
        long begin = System.currentTimeMillis();
        // try-with-resources: the original code never closed this stream.
        try (OutputStream os = new FileOutputStream(outputFile)) {
            MetroJAXBUtil.marshalPMML(pmml, os);
            log.info("Marshalled PMML in {} ms.", (System.currentTimeMillis() - begin));
        } catch (Exception e) {
            // Broad catch on purpose: the dump must never break evaluator initialization.
            log.warn("Failed to marshal PMML to {}", outputFile, e);
        }
    }

    /** Builds and verifies an evaluator for the PMML document. */
    private static ModelEvaluator<?> buildEvaluator(PMML pmml) {
        long begin = System.currentTimeMillis();
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
        modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);
        ModelEvaluator<?> evaluator = modelEvaluatorBuilder.build();
        evaluator.verify();
        log.info("Init evaluator in {} ms.", (System.currentTimeMillis() - begin));
        return evaluator;
    }


    /**
     * Evaluates one sample and renders the target field's value as a string.
     *
     * @param argumentMap input values keyed by input field name
     * @param evaluator evaluator built by {@link #initEvaluator(InputStream)}
     * @return the predicted value as text, a comma-joined list for list-valued targets, or
     *     {@code null} when the target value is absent or of an unsupported type
     */
    public static String getPredictedValue(Map<String, ?> argumentMap,
                                           ModelEvaluator<?> evaluator) {
        Map<String, ?> evaluateResult = evaluator.evaluate(argumentMap);
        log.info("evaluateResult:{}", evaluateResult);

        TargetField targetField = evaluator.getTargetField();
        Object targetFieldValue = evaluateResult.get(targetField.getFieldName());

        String predictedValue;
        if (targetFieldValue instanceof ProbabilityDistribution) {
            Object prediction = ((ProbabilityDistribution<?>) targetFieldValue).getPrediction();
            // Guard against a null prediction instead of throwing NullPointerException.
            predictedValue = (prediction != null) ? prediction.toString() : null;
            log.info("Predicted value(ProbabilityDistribution) : {}", predictedValue);
        } else if (targetFieldValue instanceof FieldValue) {
            predictedValue = ((FieldValue) targetFieldValue).asString();
            log.info("Predicted value(FieldValue) : {}", predictedValue);
        } else if (targetFieldValue instanceof List) {
            predictedValue = ((List<?>) targetFieldValue)
                    .stream()
                    .map(LightgbmTxtInitializer::valueToString)
                    .collect(Collectors.joining(","));
            log.info("Predicted value(List) : {}", predictedValue);
        } else if (targetFieldValue instanceof Number) {
            // Plain numeric result (e.g. a decoded regression score).
            predictedValue = targetFieldValue.toString();
            log.info("Predicted value(Number) : {}", predictedValue);
        } else {
            predictedValue = null;
            log.error("unknown type for targetFieldValue:{}", targetFieldValue);
        }
        return predictedValue;
    }

    /** Renders one list element; non-FieldValue elements fall back to String.valueOf instead of failing a cast. */
    private static String valueToString(Object value) {
        return (value instanceof FieldValue) ? ((FieldValue) value).asString() : String.valueOf(value);
    }
}
posted @ 2024-11-08 14:06  鱼007  阅读(22)  评论(0)  编辑  收藏  举报