Alink 模型保存与模型加载
1、pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Maven build for the FlinkML module (child of org.example:FlinkSql).
  Pulls in Alink on Flink 1.12 / Scala 2.11 plus the Flink connectors used by the
  sibling examples, and packages a jar-with-dependencies at the package phase.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>FlinkSql</artifactId>
        <groupId>org.example</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>FlinkML</artifactId>

    <properties>
        <flink12.version>1.12.1</flink12.version>
        <scala.binary.version>2.11</scala.binary.version>
        <!-- flink-cdc version in use. NOTE(review): the original note mentioned 1.3.0
             (which supports starting from a binlog file and position), but the
             configured value is 1.2.0; confirm which one is intended. -->
        <flink-cdc.version>1.2.0</flink-cdc.version>
        <hive.version>1.1.0</hive.version>
        <alink.version>1.4.0</alink.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>com.alibaba.alink</groupId>
            <artifactId>alink_core_flink-1.12_2.11</artifactId>
            <version>${alink.version}</version>
        </dependency>

        <!-- Flink Dependency -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-hive_2.11</artifactId>
            <version>${flink12.version}</version>
            <!--<scope>provided</scope>-->
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-table-api-java-bridge_2.11</artifactId>
            <version>${flink12.version}</version>
            <!--<scope>provided</scope>-->
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-table-planner_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-table-planner-blink -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-table-planner-blink_2.11</artifactId>
            <version>${flink12.version}</version>
            <!--<scope>provided</scope>-->
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-common -->
        <!--<dependency>-->
        <!--<groupId>org.apache.hadoop</groupId>-->
        <!--<artifactId>hadoop-common</artifactId>-->
        <!--<version>2.6.0</version>-->
        <!--</dependency>-->

        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-hadoop-compatibility -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-hadoop-compatibility_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>

        <!-- Hive Dependency -->
        <dependency>
            <groupId>org.apache.hive</groupId>
            <artifactId>hive-exec</artifactId>
            <version>${hive.version}</version>
            <!--<scope>provided</scope>-->
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.hive/hive-metastore -->
        <dependency>
            <groupId>org.apache.hive</groupId>
            <artifactId>hive-metastore</artifactId>
            <version>${hive.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.thrift/libfb303 -->
        <dependency>
            <groupId>org.apache.thrift</groupId>
            <artifactId>libfb303</artifactId>
            <version>0.9.2</version>
            <!--<type>pom</type>-->
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-shaded-hadoop-2-uber -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-shaded-hadoop-2-uber</artifactId>
            <version>2.6.5-7.0</version>
            <!--<scope>provided</scope>-->
        </dependency>

        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-jdbc_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>

        <!--format-->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-json</artifactId>
            <version>${flink12.version}</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-streaming-java -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-streaming-java_2.11</artifactId>
            <version>${flink12.version}</version>
            <!--<scope>provided</scope>-->
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-clients -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-clients_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-core -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-core</artifactId>
            <version>${flink12.version}</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.44</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <!-- Pinned: the RELEASE meta-version is deprecated in Maven 3 and makes
                 builds non-reproducible. -->
            <version>1.2.83</version>
            <scope>compile</scope>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-connector-kafka -->
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-kafka_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
        <dependency>
            <groupId>redis.clients</groupId>
            <artifactId>jedis</artifactId>
            <version>3.2.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.apache.bahir/flink-connector-redis -->
        <dependency>
            <groupId>org.apache.bahir</groupId>
            <artifactId>flink-connector-redis_2.11</artifactId>
            <version>1.0</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.apache.flink/flink-connector-elasticsearch-base -->
        <!-- <dependency>-->
        <!-- <groupId>org.apache.flink</groupId>-->
        <!-- <artifactId>flink-connector-elasticsearch-base_2.11</artifactId>-->
        <!-- <version>${flink12.version}</version>-->
        <!-- </dependency>-->

        <dependency>
            <groupId>com.alibaba.ververica</groupId>
            <!-- add the dependency matching your database -->
            <artifactId>flink-connector-mysql-cdc</artifactId>
            <version>${flink-cdc.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-orc-nohive_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-statebackend-rocksdb_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-hbase-1.4_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>19.0</version>
        </dependency>
        <dependency>
            <groupId>ru.yandex.clickhouse</groupId>
            <artifactId>clickhouse-jdbc</artifactId>
            <version>0.2.4</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-streaming-scala_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>

        <!-- Print verbose log output -->
        <!-- <dependency>-->
        <!-- <groupId>org.slf4j</groupId>-->
        <!-- <artifactId>slf4j-simple</artifactId>-->
        <!-- <version>1.7.25</version>-->
        <!-- <!– <scope>test</scope>–>-->
        <!-- </dependency>-->

        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-csv</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-elasticsearch6_2.11</artifactId>
            <version>${flink12.version}</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba.ververica</groupId>
            <artifactId>flink-format-changelog-json</artifactId>
            <version>1.1.0</version>
        </dependency>

        <!-- Logging dependencies. Flink requires them, otherwise it reports an error;
             they are declared explicitly because adding Hive caused a conflict. -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.25</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>1.7.25</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- Java compiler plugin -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.6.0</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                    <encoding>UTF-8</encoding>
                </configuration>
            </plugin>

            <!-- Scala compiler plugin -->
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.1.6</version>
                <configuration>
                    <scalaCompatVersion>2.11</scalaCompatVersion>
                    <scalaVersion>2.11.12</scalaVersion>
                    <encoding>UTF-8</encoding>
                    <addScalacArgs>-target:jvm-1.8</addScalacArgs>
                </configuration>
                <executions>
                    <execution>
                        <id>compile-scala</id>
                        <phase>compile</phase>
                        <goals>
                            <goal>add-source</goal>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>test-compile-scala</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>add-source</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>

            <!-- Builds the jar-with-dependencies during the package phase -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>2.6</version>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                    <archive>
                        <manifest>
                            <!-- Entry-point class of the jar (optional).
                                 NOTE(review): this class is not among the examples shown
                                 for this module; confirm it exists or update it. -->
                            <mainClass>MySqlBinlogSourceExample</mainClass>
                        </manifest>
                    </archive>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
2、保存代码
package modelExport;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.regression.GbdtRegPredictBatchOp;
import com.alibaba.alink.operator.batch.regression.GbdtRegTrainBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.sink.CsvSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import org.apache.flink.types.Row;

import java.util.Arrays;
import java.util.List;

/**
 * Trains a GBDT regression model on a small in-memory data set, prints the
 * predictions for the training rows, writes those predictions to a CSV file
 * and saves the trained model in Alink's Ak format.
 *
 * <p>Part of the FlinkSql project. Author: yang, created 2021-06-24 09:45.
 */
public class testGbdtRegPredictStreamOpExport {

    /** Destination of the prediction result set (CSV). */
    private static final String PREDICTION_CSV_PATH =
            "E:\\Flink\\FlinkSql\\FlinkML\\model\\testGbdtRegPredictStreamOp";

    /** Destination of the trained model (Ak format). */
    private static final String MODEL_PATH =
            "E:\\Flink\\FlinkSql\\FlinkML\\model\\testGbdtRegPredictStreamOpModel";

    /**
     * Runs the train / predict / export pipeline.
     *
     * @param args unused
     * @throws Exception if building or executing the Alink batch job fails
     */
    public static void main(String[] args) throws Exception {
        // 1. Training data: four feature columns followed by the label.
        List<Row> trainingRows = Arrays.asList(
                Row.of(1.0, "A", 0, 0, 0),
                Row.of(2.0, "B", 1, 1, 0),
                Row.of(3.0, "C", 2, 2, 1),
                Row.of(4.0, "D", 3, 3, 1)
        );
        BatchOperator<?> batchSource = new MemSourceBatchOp(
                trainingRows, new String[]{"f0", "f1", "f2", "f3", "label"});

        // 2. Train the GBDT regression model.
        BatchOperator<?> trainOp = new GbdtRegTrainBatchOp()
                .setLearningRate(1.0)
                .setNumTrees(3)
                .setMinSamplesPerLeaf(1)
                .setLabelCol("label")
                .setFeatureCols("f0", "f1", "f2", "f3")
                .linkFrom(batchSource);

        // 3. Configure the predictor; the prediction is written to column "pred".
        BatchOperator<?> predictBatchOp = new GbdtRegPredictBatchOp()
                .setPredictionCol("pred");

        // 4. Wire model + data into the predictor and print the predictions.
        // NOTE(review): print() presumably launches a job of its own; the sinks
        // registered below are run by the BatchOperator.execute() call at the end.
        BatchOperator<?> resultOp = predictBatchOp.linkFrom(trainOp, batchSource);
        resultOp.print();

        // 5. Save the prediction result set to a CSV file.
        // Overwrite is enabled so that re-running the example does not fail
        // because the output of a previous run already exists.
        CsvSinkBatchOp predictionSink = new CsvSinkBatchOp()
                .setFilePath(PREDICTION_CSV_PATH)
                .setOverwriteSink(true);
        resultOp.link(predictionSink);

        // 6. Save the trained model so it can be loaded later via AkSourceBatchOp.
        AkSinkBatchOp modelSink = new AkSinkBatchOp()
                .setFilePath(MODEL_PATH)
                .setOverwriteSink(true);
        trainOp.link(modelSink);

        BatchOperator.execute();
    }
}
3、读取模型代码
package modelImport;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.regression.GbdtRegPredictBatchOp;
import com.alibaba.alink.operator.batch.regression.GbdtRegTrainBatchOp;
import com.alibaba.alink.operator.batch.sink.CsvSinkBatchOp;
import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.pipeline.classification.RandomForestClassifier;
import com.alibaba.alink.pipeline.tuning.*;
import org.apache.flink.types.Row;

import java.util.Arrays;
import java.util.List;

/**
 * Loads a previously saved GBDT regression model from an Ak file and uses it
 * to predict on a small in-memory test data set, printing the result.
 *
 * <p>Part of the FlinkSql project. Author: yang, created 2021-06-24 09:45.
 */
public class testGbdtRegPredictStreamOpImport {

    /**
     * Reads the saved model, runs the prediction and prints the result.
     *
     * @param args unused
     * @throws Exception if building or executing the Alink batch job fails
     */
    public static void main(String[] args) throws Exception {
        // 1. Read the saved model.
        String filePath = "E:\\Flink\\FlinkSql\\FlinkML\\model\\testGbdtRegPredictStreamOpModel";
        AkSourceBatchOp modelSource = new AkSourceBatchOp().setFilePath(filePath);

        // 2. Test data: the four feature columns only, no label.
        List<Row> testRows = Arrays.asList(
                Row.of(6.0, "A", 3, 3),
                Row.of(10.0, "A", 1, 1),
                Row.of(8.0, "C", 2, 2),
                Row.of(9.0, "D", 3, 3)
        );
        BatchOperator<?> batchSource = new MemSourceBatchOp(
                testRows, new String[]{"f0", "f1", "f2", "f3"});

        // 3. Configure the predictor; the prediction is written to column "pred".
        BatchOperator<?> predictBatchOp = new GbdtRegPredictBatchOp()
                .setPredictionCol("pred");

        // 4. Predict: the loaded model is the first input, the data the second.
        BatchOperator<?> resultOp = predictBatchOp.linkFrom(modelSource, batchSource);
        System.out.println(">>>>>>>>>>>>>>预测结果数据>>>>>>>>>>>>");
        resultOp.print();
    }
}
本文来自博客园,作者:小白啊小白,Fighting,转载请注明原文链接:https://www.cnblogs.com/ywjfx/p/14928169.html