在使用Pipeline串联多个stage时,Model(已提前调用fit得到的模型)与非Model(尚未fit的Estimator/Transformer)的区别
train.csv数据:
id,name,age,sex
1,lyy,20,F
2,rdd,20,M
3,nyc,18,M
4,mzy,10,M
数据读取:
1 SparkSession spark = SparkSession.builder().enableHiveSupport() 2 .getOrCreate(); 3 Dataset<Row> dataset = spark 4 .read() 5 .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") 6 .option("header", true) 7 .option("inferSchema", true) 8 .option("delimiter", ",") 9 //.load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/SanFranciscoCrime/document/kaggle-旧金山犯罪分类/train-new.csv") //PreProcess1 10 .load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/DataPreprocessing/document/train.csv") //PreProcess2 11 .persist();
1 public static void PreProcess2(Dataset<Row> data) { 2 3 data.printSchema(); 4 // 重新索引标签值 5 StringIndexerModel labelIndexer = new StringIndexer() 6 .setInputCol("sex") 7 .setOutputCol("label") 8 .fit(data); 9 10 StringIndexerModel nameIndexer = new StringIndexer() 11 .setInputCol("name") 12 .setOutputCol("namenum") 13 .fit(data); 14 15 16 /* 会报错:Exception in thread "main" java.lang.IllegalArgumentException: Field "namenum" does not exist. 17 * 原因是:Model类型调用fit时,要求数据集中必须包含InputCol所指定的列名 18 * 不会将Pipeline某个stage的输出作为InputCol,即使那个stage的OutputCol指定的列名与其相同也不行 19 * StringIndexerModel name1Indexer = new StringIndexer() 20 .setInputCol("namenum") 21 .setOutputCol("namenum1") 22 .fit(data);*/ 23 24 25 /* 错误原因StringIndexerModel错误一样,features并不是data的列 26 * VectorIndexerModel featureIndexer = new VectorIndexer() 27 .setInputCol("features") 28 .setOutputCol("indexfeatures") 29 .setMaxCategories(4) 30 .fit(data);*/ 31 32 //成功 33 //原因说明:非model时,转换器不会调用fit,而会使用Pipeline某个stage的输出作为InputCol 34 //由于stage[2]即 assembler已经生成features,故而该处直接使用; 35 //但是该类型时不能单独使用,必须依赖Pipeline 36 VectorIndexer featureIndexer = new VectorIndexer() 37 .setInputCol("features") 38 .setOutputCol("indexfeatures") 39 .setMaxCategories(4); 40 41 //由上述分析可知,该处输入的列可以是多个stage的输出组成,因为VectorAssembler非model 42 //因此可以使用中间生成结果,且可以使用多个 43 VectorAssembler assembler = new VectorAssembler() 44 .setInputCols("id,namenum,age".split(",")) 45 .setOutputCol("features"); 46 47 //这里的stage的顺序很重要,一定按照依赖关系顺序放入,如下顺序就会报错: 48 //Exception in thread "main" java.lang.IllegalArgumentException: Field "features" does not exist. 49 //Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,featureIndexer,assembler}); 50 51 //将featureIndexer放到assembler即可 52 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer}); 53 54 // Train model. This also runs the indexers. 55 PipelineModel model = pipeline.fit(data); 56 57 // Make predictions. 
58 Dataset<Row> result = model.transform(data); 59 60 result.show(10, false); 61 62 }
root
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- sex: string (nullable = true)
+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1 |lyy |20 |F |1.0 |1.0 |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2 |rdd |20 |M |0.0 |2.0 |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3 |nyc |18 |M |0.0 |0.0 |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4 |mzy |10 |M |0.0 |3.0 |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+
综上分析,可以将原有代码做一简化:
1 public static void PreProcess2(Dataset<Row> data) { 2 3 data.printSchema(); 4 // 重新索引标签值 5 StringIndexer labelIndexer = new StringIndexer() 6 .setInputCol("sex") 7 .setOutputCol("label"); 8 9 StringIndexer nameIndexer = new StringIndexer() 10 .setInputCol("name") 11 .setOutputCol("namenum"); 12 13 VectorIndexer featureIndexer = new VectorIndexer() 14 .setInputCol("features") 15 .setOutputCol("indexfeatures") 16 .setMaxCategories(4); 17 18 19 VectorAssembler assembler = new VectorAssembler() 20 .setInputCols("id,namenum,age".split(",")) 21 .setOutputCol("features"); 22 23 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer}); 24 25 // Train model. This also runs the indexers. 26 PipelineModel model = pipeline.fit(data); //以这里的data为基准数据 27 28 // Make predictions. 29 Dataset<Row> result = model.transform(data); 30 31 result.show(10, false); 32 33 }
运行结果:
root
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- sex: string (nullable = true)
+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1 |lyy |20 |F |1.0 |1.0 |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2 |rdd |20 |M |0.0 |2.0 |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3 |nyc |18 |M |0.0 |0.0 |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4 |mzy |10 |M |0.0 |3.0 |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+