Spark2 ML包之决策树分类Decision tree classifier详细解说
所用数据源,请参考本人博客http://www.cnblogs.com/wwxbi/p/6063613.html
1.导入包
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.Dataset import org.apache.spark.sql.Row import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrameReader import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.Encoder import org.apache.spark.sql.DataFrameStatFunctions import org.apache.spark.sql.functions._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.feature.IndexToString import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.feature.VectorSlicer
2.加载数据源
// Create (or reuse) the SparkSession used by every step of this walkthrough.
val spark = SparkSession
  .builder()
  .appName("Spark decision tree classifier")
  .config("spark.some.config.option", "some-value") // NOTE(review): looks like a placeholder setting - confirm or remove
  .getOrCreate()

// For implicit conversions like converting RDDs to DataFrames.
// Must be a real statement (not swallowed by a preceding comment): `toDF` below depends on it.
import spark.implicits._

// Sample rows only; for the full data source see the author's blog post:
// http://www.cnblogs.com/wwxbi/p/6063613.html
// Integer literals are widened to Double by the declared tuple type.
val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
  (0, "male", 37, 10, "no", 3, 18, 7, 4),
  (0, "female", 27, 4, "no", 4, 14, 6, 4),
  (0, "female", 32, 15, "yes", 1, 12, 1, 4),
  (0, "male", 57, 15, "yes", 5, 18, 6, 5),
  (0, "male", 22, 0.75, "no", 2, 17, 6, 3),
  (0, "female", 32, 1.5, "no", 2, 17, 5, 5))

val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

// Register the DataFrame so it can be queried with SQL under the name "data".
data.createOrReplaceTempView("data")

// Convert the string columns to numeric values and derive the binary label:
//   label    = 0 when affairs = 0, otherwise 1
//   gender   = 0 for 'female', otherwise 1
//   children = 0 for 'no', otherwise 1
val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"
val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"
val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"

// Label column first, then the eight feature columns, all numeric.
val dataLabelDF = spark.sql(s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data")
3.创建决策树模型
// Input columns, in the order they are packed into the feature vector.
// The "feature N" indices printed by the tree's debug string refer to positions in this array.
val featuresArray = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

// Assemble the individual columns into a single feature-vector column.
val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Index the label, adding metadata to the label column.
// Fitted on the whole dataset (before the split) and then reused as a pipeline stage.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF)
labelIndexer.transform(vecDF).show(10, truncate = false)

// Automatically identify categorical features and index them.
// Features with more than 5 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF)
featureIndexer.transform(vecDF).show(10, truncate = false)

// Split the data into training and test sets (30% held out for testing).
// The split uses the same seed as the classifier below so that repeated runs are comparable.
val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3), 123456L)

// Configure the decision tree.
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setImpurity("entropy") // impurity measure
  .setMaxBins(100) // maximum number of bins used to discretize continuous features
  .setMaxDepth(5) // maximum depth of the tree
  .setMinInfoGain(0.01) // minimum information gain for a node split, value in [0,1]
  .setMinInstancesPerNode(10) // minimum number of samples each node must contain
  .setSeed(123456)

// Convert indexed predictions back to the original labels.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions on the held-out rows.
val predictions = model.transform(testData)

// Show a few example rows.
predictions.select("predictedLabel", "label", "features").show(10, truncate = false)

// Compare (prediction, true label) and compute the test error.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("Test Error = " + (1.0 - accuracy))

// Index 2 in stages(2) is the position of `dt` in the pipeline's stage array;
// the fitted stage is cast to DecisionTreeClassificationModel to reach the tree-specific accessors.
val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]

// Inspect the fitted model. These are bare expressions: in an interactive shell each one
// echoes its value; in a compiled program they have no effect.
treeModel.getLabelCol
treeModel.getFeaturesCol
treeModel.featureImportances
treeModel.getPredictionCol
treeModel.getProbabilityCol
treeModel.numClasses
treeModel.numFeatures
treeModel.depth
treeModel.numNodes
treeModel.getImpurity
treeModel.getMaxBins
treeModel.getMaxDepth
treeModel.getMaxMemoryInMB
treeModel.getMinInfoGain
treeModel.getMinInstancesPerNode

// Print the learned tree as nested If/Else rules.
println("Learned classification tree model:\n" + treeModel.toDebugString)
4.代码执行结果
val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating") data.show(10, truncate = false) +-------+------+----+------------+--------+-------------+---------+----------+------+ |affairs|gender|age |yearsmarried|children|religiousness|education|occupation|rating| +-------+------+----+------------+--------+-------------+---------+----------+------+ |0.0 |male |37.0|10.0 |no |3.0 |18.0 |7.0 |4.0 | |0.0 |female|27.0|4.0 |no |4.0 |14.0 |6.0 |4.0 | |0.0 |female|32.0|15.0 |yes |1.0 |12.0 |1.0 |4.0 | |0.0 |male |57.0|15.0 |yes |5.0 |18.0 |6.0 |5.0 | |0.0 |male |22.0|0.75 |no |2.0 |17.0 |6.0 |3.0 | |0.0 |female|32.0|1.5 |no |2.0 |17.0 |5.0 |5.0 | |0.0 |female|22.0|0.75 |no |2.0 |12.0 |1.0 |3.0 | |0.0 |male |57.0|15.0 |yes |2.0 |14.0 |4.0 |4.0 | |0.0 |female|32.0|15.0 |yes |4.0 |16.0 |1.0 |2.0 | |0.0 |male |22.0|1.5 |no |4.0 |14.0 |4.0 |5.0 | +-------+------+----+------------+--------+-------------+---------+----------+------+ only showing top 10 rows data.createOrReplaceTempView("data") // 字符类型转换成数值 val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label" val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender" val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children" val dataLabelDF = spark.sql(s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data") dataLabelDF.show(10, truncate = false) +-----+------+----+------------+--------+-------------+---------+----------+------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating| +-----+------+----+------------+--------+-------------+---------+----------+------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 | |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 | |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 | |0.0 |1.0 
|22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 | |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 | |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 | |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 | |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 | +-----+------+----+------------+--------+-------------+---------+----------+------+ only showing top 10 rows val featuresArray = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating") // 字段转换成特征向量 val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features") val vecDF: DataFrame = assembler.transform(dataLabelDF) vecDF.show(10, truncate = false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]| |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]| |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]| |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]| |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]| |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]| |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]| |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] | 
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ only showing top 10 rows // 索引标签,将元数据添加到标签列中 val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF) labelIndexer.transform(vecDF).show(10, truncate = false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |indexedLabel| +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|0.0 | |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |0.0 | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|0.0 | |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|0.0 | |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|0.0 | |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |0.0 | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|0.0 | |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|0.0 | |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|0.0 | |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |0.0 | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+ only showing top 10 rows // 自动识别分类的特征,并对它们进行索引 // 具有大于5个不同的值的特征被视为连续。 val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF) featureIndexer.transform(vecDF).show(10, truncate = false) 
featureIndexer.transform(vecDF).show(10, truncate = false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------------------------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |indexedFeatures | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------------------------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|[1.0,37.0,10.0,0.0,2.0,18.0,7.0,3.0]| |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |[0.0,27.0,4.0,0.0,3.0,14.0,6.0,3.0] | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|[0.0,32.0,15.0,1.0,0.0,12.0,1.0,3.0]| |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|[1.0,57.0,15.0,1.0,4.0,18.0,6.0,4.0]| |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|[1.0,22.0,0.75,0.0,1.0,17.0,6.0,2.0]| |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |[0.0,32.0,1.5,0.0,1.0,17.0,5.0,4.0] | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|[0.0,22.0,0.75,0.0,1.0,12.0,1.0,2.0]| |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|[1.0,57.0,15.0,1.0,1.0,14.0,4.0,3.0]| |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|[0.0,32.0,15.0,1.0,3.0,16.0,1.0,1.0]| |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |[1.0,22.0,1.5,0.0,3.0,14.0,4.0,4.0] | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------------------------------+ only showing top 10 rows // 将数据分为训练和测试集(30%进行测试) val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3)) // 
训练决策树模型 val dt = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setImpurity("entropy").setMaxBins(100).setMaxDepth(5).setMinInfoGain(0.01).setMinInstancesPerNode(10).setSeed(123456) //.setLabelCol("indexedLabel") //.setFeaturesCol("indexedFeatures") //.setImpurity("entropy") // 不纯度 //.setMaxBins(100) // 离散化"连续特征"的最大划分数 //.setMaxDepth(5) // 树的最大深度 //.setMinInfoGain(0.01) //一个节点分裂的最小信息增益,值为[0,1] //.setMinInstancesPerNode(10) //每个节点包含的最小样本数 //.setSeed(123456) // 将索引标签转换回原始标签 val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels) // Chain indexers and tree in a Pipeline. val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) // Train model. This also runs the indexers. val model = pipeline.fit(trainingData) // 作出预测 val predictions = model.transform(testData) // 选择几个示例行展示 predictions.select("predictedLabel", "label", "features").show(10, truncate = false) +--------------+-----+-------------------------------------+ |predictedLabel|label|features | +--------------+-----+-------------------------------------+ |0.0 |0.0 |[0.0,22.0,0.125,0.0,2.0,14.0,4.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]| |0.0 |0.0 |[0.0,22.0,0.75,0.0,2.0,18.0,6.0,5.0] | |0.0 |0.0 |[0.0,22.0,0.75,0.0,3.0,16.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,0.75,0.0,4.0,16.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,14.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,16.0,5.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,16.0,5.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,17.0,5.0,4.0] | +--------------+-----+-------------------------------------+ // 选择(预测标签,实际标签),并计算测试误差。 val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) accuracy: Double = 0.6972972972972973 println("Test 
Error = " + (1.0 - accuracy)) Test Error = 0.3027027027027027 // 这里的stages(2)中的“2”对应pipeline中的“dt”,将model强制转换为DecisionTreeClassificationModel类型 val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] DecisionTreeClassificationModel (uid=dtc_b950f91d35f8) of depth 5 with 43 nodes treeModel.getLabelCol String = indexedLabel treeModel.getFeaturesCol String = indexedFeatures treeModel.featureImportances Vector = (8,[0,1,2,4,5,6,7],[0.012972759843658999,0.1075317063921102,0.11654682273543511,0.17869552275855793,0.07532637852021348,0.27109893303920024,0.237827876710824]) treeModel.getPredictionCol String = prediction treeModel.getProbabilityCol String = probability treeModel.numClasses Int = 2 treeModel.numFeatures Int = 8 treeModel.depth Int = 5 treeModel.numNodes Int = 43 treeModel.getImpurity String = entropy treeModel.getMaxBins Int = 100 treeModel.getMaxDepth Int = 5 treeModel.getMaxMemoryInMB Int = 256 treeModel.getMinInfoGain Double = 0.01 treeModel.getMinInstancesPerNode Int = 10 // 查看决策树 println("Learned classification tree model:\n" + treeModel.toDebugString) Learned classification tree model: DecisionTreeClassificationModel (uid=dtc_b950f91d35f8) of depth 5 with 43 nodes // 例如“feature 7 in {0.0,1.0,2.0}”中的“{0.0,1.0,2.0}” // 具体解释请参考本人博客http://www.cnblogs.com/wwxbi/p/6125493.html“VectorIndexer自动识别分类的特征,并对它们进行索引” If (feature 7 in {0.0,1.0,2.0}) If (feature 7 in {0.0,2.0}) If (feature 4 in {0.0,4.0}) Predict: 1.0 Else (feature 4 not in {0.0,4.0}) If (feature 1 <= 32.0) If (feature 1 <= 27.0) Predict: 0.0 Else (feature 1 > 27.0) Predict: 1.0 Else (feature 1 > 32.0) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0}) If (feature 4 in {0.0,1.0,3.0,4.0}) If (feature 0 in {0.0}) If (feature 2 <= 7.0) Predict: 0.0 Else (feature 2 > 7.0) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,3.0,4.0}) Predict: 1.0 Else (feature 7 not in {0.0,1.0,2.0}) If 
(feature 2 <= 4.0) If (feature 6 <= 3.0) If (feature 6 <= 1.0) Predict: 0.0 Else (feature 6 > 1.0) Predict: 0.0 Else (feature 6 > 3.0) If (feature 5 <= 16.0) If (feature 2 <= 0.75) Predict: 0.0 Else (feature 2 > 0.75) Predict: 0.0 Else (feature 5 > 16.0) If (feature 7 in {4.0}) Predict: 0.0 Else (feature 7 not in {4.0}) Predict: 0.0 Else (feature 2 > 4.0) If (feature 6 <= 3.0) If (feature 4 in {0.0,1.0,2.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,2.0}) If (feature 7 in {4.0}) Predict: 0.0 Else (feature 7 not in {4.0}) Predict: 0.0 Else (feature 6 > 3.0) If (feature 4 in {0.0,2.0,3.0,4.0}) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 0.0 Else (feature 4 not in {0.0,2.0,3.0,4.0}) If (feature 1 <= 37.0) Predict: 1.0 Else (feature 1 > 37.0) Predict: 0.0