Loading a PMML Model in Java

Note: if loading fails, try changing the version declared in the PMML file to 4.3.
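The note above can also be applied programmatically instead of editing the file by hand. The following is only a minimal sketch: the class name `PmmlVersionPatch` and the method `toVersion43` are hypothetical, and it assumes that the version lives in the `version` attribute of the root `<PMML>` element and in a namespace URI of the form `http://www.dmg.org/PMML-4_x`. Relabelling a file this way does not convert its content, so a model that uses features newer than 4.3 may still fail to load.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class PmmlVersionPatch {
    // Matches the version attribute value inside the root <PMML ...> tag only,
    // so the version="1.0" of the <?xml ...?> declaration is left alone.
    private static final Pattern ROOT_VERSION =
            Pattern.compile("<PMML\\b[^>]*?\\bversion=\"([^\"]*)\"");

    // Assumption: the namespace URI follows the http://www.dmg.org/PMML-<major>_<minor> form.
    private static final Pattern NAMESPACE =
            Pattern.compile("http://www\\.dmg\\.org/PMML-\\d+_\\d+");

    // Hypothetical helper: relabels the PMML text as version 4.3.
    static String toVersion43(String pmmlXml) {
        String patched = NAMESPACE.matcher(pmmlXml)
                .replaceAll("http://www.dmg.org/PMML-4_3");
        Matcher m = ROOT_VERSION.matcher(patched);
        if (!m.find()) {
            // No version attribute on the root element; only the namespace (if any) was rewritten.
            return patched;
        }
        // Replace just the attribute value captured by group 1.
        return patched.substring(0, m.start(1)) + "4.3" + patched.substring(m.end(1));
    }

    public static void main(String[] args) {
        String xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
                + "<PMML xmlns=\"http://www.dmg.org/PMML-4_4\" version=\"4.4\"></PMML>";
        System.out.println(toVersion43(xml));
    }
}
```

The patched text can then be written back to the `.pmml` file, or wrapped in a `ByteArrayInputStream` and passed to the loader.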

Dependencies

<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.1</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.1</version>
</dependency>

------------Add to resources--------------
<resources>
    <resource>
        <directory>src/main/resources</directory>
        <filtering>true</filtering>
        <includes>
            <include>application.yml</include>
            <include>application-${activatedProperties}.yml</include>
            <include>**/**.pmml</include>
        </includes>
    </resource>
</resources>

 

Loading the model
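import java.io.IOException;
import java.io.InputStream;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.PMMLUtil;
import org.springframework.core.io.ClassPathResource;
import org.xml.sax.SAXException;

@Bean
@Qualifier("evaluator")
public Evaluator load() {
    PMML pmml;
    // try-with-resources closes the stream once the model has been parsed
    try (InputStream is = new ClassPathResource("META-INF/xxx.pmml").getInputStream()) {
        pmml = PMMLUtil.unmarshal(is);
    } catch (IOException e) {
        log.error("Get resource:xxx.pmml failed! error msg:{}", e.getMessage());
        // Fail fast instead of continuing with a null model
        throw new IllegalStateException("Cannot read xxx.pmml", e);
    } catch (JAXBException | SAXException e) {
        log.error("Parse resource:xxx.pmml failed! error msg:{}", e.getMessage());
        throw new IllegalStateException("Cannot parse xxx.pmml", e);
    }
    ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
    Evaluator modelEvaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
    // Run the evaluator's self-check before exposing it as a bean
    modelEvaluator.verify();
    log.info("load model successful!");
    return modelEvaluator;
}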

 

Calling the model
@Autowired
@Qualifier("evaluator")
private Evaluator evaluator;


@Test
public void moduleTest() {
    List<InputField> inputFields = evaluator.getInputFields();
    // Raw feature values, keyed by the field names used in the PMML file
    Map<String, Object> map = new LinkedHashMap<>();
    map.put("x1", 50);
    map.put("x2", 50);
    map.put("x3", 7);
    map.put("x4", 60);
    map.put("x5", 250);
    map.put("x6", 6);
    map.put("x7", 80);
    map.put("x8", 260);
    map.put("x9", 48);
    map.put("x10", 1);
    map.put("x11", 5);
    map.put("x12", 1);
    map.put("x13", 0);
    // Convert each raw value into the FieldValue the model expects
    Map<FieldName, FieldValue> args = new HashMap<>();
    for (InputField inputField : inputFields) {
        FieldName name = inputField.getName();
        // A key missing from the map yields null here
        Object value = map.get(name.getValue());
        FieldValue inputFieldValue = inputField.prepare(value);
        args.put(name, inputFieldValue);
    }
    Map<FieldName, ?> evaluate = evaluator.evaluate(args);
    // Read the result for the first target field
    Object value = evaluate.get(evaluator.getTargetFields().get(0).getName());
    // Serialize the result to JSON and pick the value stored under category "1.0";
    // this path depends on the structure of this particular model's output
    BigDecimal predictValue = JSON.parseObject(JSON.toJSONString(value)).getJSONObject("values").getJSONObject("1.0").getBigDecimal("value");
    log.info("result:{}", predictValue);
}
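In the test above, a field whose name is missing from `map` is passed to `inputField.prepare` as `null` without any warning. The sketch below shows one way to check the input map for completeness before evaluating. It uses only the standard library; the class `InputCheck` and the method `missingFields` are hypothetical names, and the list of required names is assumed to be collected from `inputField.getName().getValue()` for each input field.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class InputCheck {
    // Hypothetical helper: returns the required field names that have
    // no entry (or a null entry) in the input map, in the order given.
    static List<String> missingFields(List<String> requiredNames, Map<String, ?> input) {
        List<String> missing = new ArrayList<>();
        for (String name : requiredNames) {
            if (input.get(name) == null) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        Map<String, Object> map = new LinkedHashMap<>();
        map.put("x1", 50);
        map.put("x2", 50);
        // "x3" is deliberately left out to show the check
        List<String> required = Arrays.asList("x1", "x2", "x3");
        System.out.println(missingFields(required, map)); // prints [x3]
    }
}
```

Whether a missing field should be treated as an error is a design choice: some models define replacement values for missing inputs, in which case logging the list may be enough.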

 

 

posted by 余额一个亿