Spark2 Random Forests 随机森林
随机森林是由多棵决策树组成的集成(ensemble)模型。它通过组合许多决策树的预测结果来降低过拟合的风险。spark.ml 的随机森林实现同时支持连续特征和类别(分类)特征,可用于二分类、多分类以及回归任务。
导入包
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.Dataset import org.apache.spark.sql.Row import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrameReader import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.Encoder import org.apache.spark.sql.DataFrameStatFunctions import org.apache.spark.sql.functions._ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.feature.{ IndexToString, StringIndexer, VectorIndexer } import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{ RandomForestClassificationModel, RandomForestClassifier } import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
导入源数据
// affairs:一年来婚外情的频率 // gender:性别 // age:年龄 // yearsmarried:婚龄 // children:是否有小孩 // religiousness:宗教信仰程度(5分制,1分表示反对,5分表示非常信仰) // education:学历 // occupation:职业(逆向编号的戈登7种分类) // rating:对婚姻的自我评分(5分制,1表示非常不幸福,5表示非常幸福) val spark = SparkSession.builder().appName("Spark Random Forest Classifier").config("spark.some.config.option", "some-value").getOrCreate() // For implicit conversions like converting RDDs to DataFrames import spark.implicits._ val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List( (0, "male", 37, 10, "no", 3, 18, 7, 4), (0, "female", 27, 4, "no", 4, 14, 6, 4), (0, "female", 32, 15, "yes", 1, 12, 1, 4), (0, "male", 57, 15, "yes", 5, 18, 6, 5), (0, "male", 22, 0.75, "no", 2, 17, 6, 3), (0, "female", 32, 1.5, "no", 2, 17, 5, 5), (0, "female", 22, 0.75, "no", 2, 12, 1, 3), (0, "male", 57, 15, "yes", 2, 14, 4, 4), (0, "female", 32, 15, "yes", 4, 16, 1, 2), (0, "male", 22, 1.5, "no", 4, 14, 4, 5), (0, "male", 37, 15, "yes", 2, 20, 7, 2), (0, "male", 27, 4, "yes", 4, 18, 6, 4), (0, "male", 47, 15, "yes", 5, 17, 6, 4), (0, "female", 22, 1.5, "no", 2, 17, 5, 4), (0, "female", 27, 4, "no", 4, 14, 5, 4), (0, "female", 37, 15, "yes", 1, 17, 5, 5), (0, "female", 37, 15, "yes", 2, 18, 4, 3), (0, "female", 22, 0.75, "no", 3, 16, 5, 4), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 27, 10, "yes", 2, 14, 1, 5), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 27, 10, "yes", 4, 16, 5, 4), (0, "female", 32, 10, "yes", 3, 14, 1, 5), (0, "male", 37, 4, "yes", 2, 20, 6, 4), (0, "female", 22, 1.5, "no", 2, 18, 5, 5), (0, "female", 27, 7, "no", 4, 16, 1, 5), (0, "male", 42, 15, "yes", 5, 20, 6, 4), (0, "male", 27, 4, "yes", 3, 16, 5, 5), (0, "female", 27, 4, "yes", 3, 17, 5, 4), (0, "male", 42, 15, "yes", 4, 20, 6, 3), (0, "female", 22, 1.5, "no", 3, 16, 5, 5), (0, "male", 27, 0.417, "no", 4, 17, 6, 4), (0, "female", 42, 15, "yes", 5, 14, 5, 4), (0, "male", 32, 4, "yes", 
1, 18, 6, 4), (0, "female", 22, 1.5, "no", 4, 16, 5, 3), (0, "female", 42, 15, "yes", 3, 12, 1, 4), (0, "female", 22, 4, "no", 4, 17, 5, 5), (0, "male", 22, 1.5, "yes", 1, 14, 3, 5), (0, "female", 22, 0.75, "no", 3, 16, 1, 5), (0, "male", 32, 10, "yes", 5, 20, 6, 5), (0, "male", 52, 15, "yes", 5, 18, 6, 3), (0, "female", 22, 0.417, "no", 5, 14, 1, 4), (0, "female", 27, 4, "yes", 2, 18, 6, 1), (0, "female", 32, 7, "yes", 5, 17, 5, 3), (0, "male", 22, 4, "no", 3, 16, 5, 5), (0, "female", 27, 7, "yes", 4, 18, 6, 5), (0, "female", 42, 15, "yes", 2, 18, 5, 4), (0, "male", 27, 1.5, "yes", 4, 16, 3, 5), (0, "male", 42, 15, "yes", 2, 20, 6, 4), (0, "female", 22, 0.75, "no", 5, 14, 3, 5), (0, "male", 32, 7, "yes", 2, 20, 6, 4), (0, "male", 27, 4, "yes", 5, 20, 6, 5), (0, "male", 27, 10, "yes", 4, 20, 6, 4), (0, "male", 22, 4, "no", 1, 18, 5, 5), (0, "female", 37, 15, "yes", 4, 14, 3, 1), (0, "male", 22, 1.5, "yes", 5, 16, 4, 4), (0, "female", 37, 15, "yes", 4, 17, 1, 5), (0, "female", 27, 0.75, "no", 4, 17, 5, 4), (0, "male", 32, 10, "yes", 4, 20, 6, 4), (0, "female", 47, 15, "yes", 5, 14, 7, 2), (0, "male", 37, 10, "yes", 3, 20, 6, 4), (0, "female", 22, 0.75, "no", 2, 16, 5, 5), (0, "male", 27, 4, "no", 2, 18, 4, 5), (0, "male", 32, 7, "no", 4, 20, 6, 4), (0, "male", 42, 15, "yes", 2, 17, 3, 5), (0, "male", 37, 10, "yes", 4, 20, 6, 4), (0, "female", 47, 15, "yes", 3, 17, 6, 5), (0, "female", 22, 1.5, "no", 5, 16, 5, 5), (0, "female", 27, 1.5, "no", 2, 16, 6, 4), (0, "female", 27, 4, "no", 3, 17, 5, 5), (0, "female", 32, 10, "yes", 5, 14, 4, 5), (0, "female", 22, 0.125, "no", 2, 12, 5, 5), (0, "male", 47, 15, "yes", 4, 14, 4, 3), (0, "male", 32, 15, "yes", 1, 14, 5, 5), (0, "male", 27, 7, "yes", 4, 16, 5, 5), (0, "female", 22, 1.5, "yes", 3, 16, 5, 5), (0, "male", 27, 4, "yes", 3, 17, 6, 5), (0, "female", 22, 1.5, "no", 3, 16, 5, 5), (0, "male", 57, 15, "yes", 2, 14, 7, 2), (0, "male", 17.5, 1.5, "yes", 3, 18, 6, 5), (0, "male", 57, 15, "yes", 4, 20, 6, 5), (0, "female", 
22, 0.75, "no", 2, 16, 3, 4), (0, "male", 42, 4, "no", 4, 17, 3, 3), (0, "female", 22, 1.5, "yes", 4, 12, 1, 5), (0, "female", 22, 0.417, "no", 1, 17, 6, 4), (0, "female", 32, 15, "yes", 4, 17, 5, 5), (0, "female", 27, 1.5, "no", 3, 18, 5, 2), (0, "female", 22, 1.5, "yes", 3, 14, 1, 5), (0, "female", 37, 15, "yes", 3, 14, 1, 4), (0, "female", 32, 15, "yes", 4, 14, 3, 4), (0, "male", 37, 10, "yes", 2, 14, 5, 3), (0, "male", 37, 10, "yes", 4, 16, 5, 4), (0, "male", 57, 15, "yes", 5, 20, 5, 3), (0, "male", 27, 0.417, "no", 1, 16, 3, 4), (0, "female", 42, 15, "yes", 5, 14, 1, 5), (0, "male", 57, 15, "yes", 3, 16, 6, 1), (0, "male", 37, 10, "yes", 1, 16, 6, 4), (0, "male", 37, 15, "yes", 3, 17, 5, 5), (0, "male", 37, 15, "yes", 4, 20, 6, 5), (0, "female", 27, 10, "yes", 5, 14, 1, 5), (0, "male", 37, 10, "yes", 2, 18, 6, 4), (0, "female", 22, 0.125, "no", 4, 12, 4, 5), (0, "male", 57, 15, "yes", 5, 20, 6, 5), (0, "female", 37, 15, "yes", 4, 18, 6, 4), (0, "male", 22, 4, "yes", 4, 14, 6, 4), (0, "male", 27, 7, "yes", 4, 18, 5, 4), (0, "male", 57, 15, "yes", 4, 20, 5, 4), (0, "male", 32, 15, "yes", 3, 14, 6, 3), (0, "female", 22, 1.5, "no", 2, 14, 5, 4), (0, "female", 32, 7, "yes", 4, 17, 1, 5), (0, "female", 37, 15, "yes", 4, 17, 6, 5), (0, "female", 32, 1.5, "no", 5, 18, 5, 5), (0, "male", 42, 10, "yes", 5, 20, 7, 4), (0, "female", 27, 7, "no", 3, 16, 5, 4), (0, "male", 37, 15, "no", 4, 20, 6, 5), (0, "male", 37, 15, "yes", 4, 14, 3, 2), (0, "male", 32, 10, "no", 5, 18, 6, 4), (0, "female", 22, 0.75, "no", 4, 16, 1, 5), (0, "female", 27, 7, "yes", 4, 12, 2, 4), (0, "female", 27, 7, "yes", 2, 16, 2, 5), (0, "female", 42, 15, "yes", 5, 18, 5, 4), (0, "male", 42, 15, "yes", 4, 17, 5, 3), (0, "female", 27, 7, "yes", 2, 16, 1, 2), (0, "female", 22, 1.5, "no", 3, 16, 5, 5), (0, "male", 37, 15, "yes", 5, 20, 6, 5), (0, "female", 22, 0.125, "no", 2, 14, 4, 5), (0, "male", 27, 1.5, "no", 4, 16, 5, 5), (0, "male", 32, 1.5, "no", 2, 18, 6, 5), (0, "male", 27, 1.5, "no", 2, 17, 6, 
5), (0, "female", 27, 10, "yes", 4, 16, 1, 3), (0, "male", 42, 15, "yes", 4, 18, 6, 5), (0, "female", 27, 1.5, "no", 2, 16, 6, 5), (0, "male", 27, 4, "no", 2, 18, 6, 3), (0, "female", 32, 10, "yes", 3, 14, 5, 3), (0, "female", 32, 15, "yes", 3, 18, 5, 4), (0, "female", 22, 0.75, "no", 2, 18, 6, 5), (0, "female", 37, 15, "yes", 2, 16, 1, 4), (0, "male", 27, 4, "yes", 4, 20, 5, 5), (0, "male", 27, 4, "no", 1, 20, 5, 4), (0, "female", 27, 10, "yes", 2, 12, 1, 4), (0, "female", 32, 15, "yes", 5, 18, 6, 4), (0, "male", 27, 7, "yes", 5, 12, 5, 3), (0, "male", 52, 15, "yes", 2, 18, 5, 4), (0, "male", 27, 4, "no", 3, 20, 6, 3), (0, "male", 37, 4, "yes", 1, 18, 5, 4), (0, "male", 27, 4, "yes", 4, 14, 5, 4), (0, "female", 52, 15, "yes", 5, 12, 1, 3), (0, "female", 57, 15, "yes", 4, 16, 6, 4), (0, "male", 27, 7, "yes", 1, 16, 5, 4), (0, "male", 37, 7, "yes", 4, 20, 6, 3), (0, "male", 22, 0.75, "no", 2, 14, 4, 3), (0, "male", 32, 4, "yes", 2, 18, 5, 3), (0, "male", 37, 15, "yes", 4, 20, 6, 3), (0, "male", 22, 0.75, "yes", 2, 14, 4, 3), (0, "male", 42, 15, "yes", 4, 20, 6, 3), (0, "female", 52, 15, "yes", 5, 17, 1, 1), (0, "female", 37, 15, "yes", 4, 14, 1, 2), (0, "male", 27, 7, "yes", 4, 14, 5, 3), (0, "male", 32, 4, "yes", 2, 16, 5, 5), (0, "female", 27, 4, "yes", 2, 18, 6, 5), (0, "female", 27, 4, "yes", 2, 18, 5, 5), (0, "male", 37, 15, "yes", 5, 18, 6, 5), (0, "female", 47, 15, "yes", 5, 12, 5, 4), (0, "female", 32, 10, "yes", 3, 17, 1, 4), (0, "female", 27, 1.5, "yes", 4, 17, 1, 2), (0, "female", 57, 15, "yes", 2, 18, 5, 2), (0, "female", 22, 1.5, "no", 4, 14, 5, 4), (0, "male", 42, 15, "yes", 3, 14, 3, 4), (0, "male", 57, 15, "yes", 4, 9, 2, 2), (0, "male", 57, 15, "yes", 4, 20, 6, 5), (0, "female", 22, 0.125, "no", 4, 14, 4, 5), (0, "female", 32, 10, "yes", 4, 14, 1, 5), (0, "female", 42, 15, "yes", 3, 18, 5, 4), (0, "female", 27, 1.5, "no", 2, 18, 6, 5), (0, "male", 32, 0.125, "yes", 2, 18, 5, 2), (0, "female", 27, 4, "no", 3, 16, 5, 4), (0, "female", 27, 10, "yes", 
2, 16, 1, 4), (0, "female", 32, 7, "yes", 4, 16, 1, 3), (0, "female", 37, 15, "yes", 4, 14, 5, 4), (0, "female", 42, 15, "yes", 5, 17, 6, 2), (0, "male", 32, 1.5, "yes", 4, 14, 6, 5), (0, "female", 32, 4, "yes", 3, 17, 5, 3), (0, "female", 37, 7, "no", 4, 18, 5, 5), (0, "female", 22, 0.417, "yes", 3, 14, 3, 5), (0, "female", 27, 7, "yes", 4, 14, 1, 5), (0, "male", 27, 0.75, "no", 3, 16, 5, 5), (0, "male", 27, 4, "yes", 2, 20, 5, 5), (0, "male", 32, 10, "yes", 4, 16, 4, 5), (0, "male", 32, 15, "yes", 1, 14, 5, 5), (0, "male", 22, 0.75, "no", 3, 17, 4, 5), (0, "female", 27, 7, "yes", 4, 17, 1, 4), (0, "male", 27, 0.417, "yes", 4, 20, 5, 4), (0, "male", 37, 15, "yes", 4, 20, 5, 4), (0, "female", 37, 15, "yes", 2, 14, 1, 3), (0, "male", 22, 4, "yes", 1, 18, 5, 4), (0, "male", 37, 15, "yes", 4, 17, 5, 3), (0, "female", 22, 1.5, "no", 2, 14, 4, 5), (0, "male", 52, 15, "yes", 4, 14, 6, 2), (0, "female", 22, 1.5, "no", 4, 17, 5, 5), (0, "male", 32, 4, "yes", 5, 14, 3, 5), (0, "male", 32, 4, "yes", 2, 14, 3, 5), (0, "female", 22, 1.5, "no", 3, 16, 6, 5), (0, "male", 27, 0.75, "no", 2, 18, 3, 3), (0, "female", 22, 7, "yes", 2, 14, 5, 2), (0, "female", 27, 0.75, "no", 2, 17, 5, 3), (0, "female", 37, 15, "yes", 4, 12, 1, 2), (0, "female", 22, 1.5, "no", 1, 14, 1, 5), (0, "female", 37, 10, "no", 2, 12, 4, 4), (0, "female", 37, 15, "yes", 4, 18, 5, 3), (0, "female", 42, 15, "yes", 3, 12, 3, 3), (0, "male", 22, 4, "no", 2, 18, 5, 5), (0, "male", 52, 7, "yes", 2, 20, 6, 2), (0, "male", 27, 0.75, "no", 2, 17, 5, 5), (0, "female", 27, 4, "no", 2, 17, 4, 5), (0, "male", 42, 1.5, "no", 5, 20, 6, 5), (0, "male", 22, 1.5, "no", 4, 17, 6, 5), (0, "male", 22, 4, "no", 4, 17, 5, 3), (0, "female", 22, 4, "yes", 1, 14, 5, 4), (0, "male", 37, 15, "yes", 5, 20, 4, 5), (0, "female", 37, 10, "yes", 3, 16, 6, 3), (0, "male", 42, 15, "yes", 4, 17, 6, 5), (0, "female", 47, 15, "yes", 4, 17, 5, 5), (0, "male", 22, 1.5, "no", 4, 16, 5, 4), (0, "female", 32, 10, "yes", 3, 12, 1, 4), (0, "female", 22, 
7, "yes", 1, 14, 3, 5), (0, "female", 32, 10, "yes", 4, 17, 5, 4), (0, "male", 27, 1.5, "yes", 2, 16, 2, 4), (0, "male", 37, 15, "yes", 4, 14, 5, 5), (0, "male", 42, 4, "yes", 3, 14, 4, 5), (0, "female", 37, 15, "yes", 5, 14, 5, 4), (0, "female", 32, 7, "yes", 4, 17, 5, 5), (0, "female", 42, 15, "yes", 4, 18, 6, 5), (0, "male", 27, 4, "no", 4, 18, 6, 4), (0, "male", 22, 0.75, "no", 4, 18, 6, 5), (0, "male", 27, 4, "yes", 4, 14, 5, 3), (0, "female", 22, 0.75, "no", 5, 18, 1, 5), (0, "female", 52, 15, "yes", 5, 9, 5, 5), (0, "male", 32, 10, "yes", 3, 14, 5, 5), (0, "female", 37, 15, "yes", 4, 16, 4, 4), (0, "male", 32, 7, "yes", 2, 20, 5, 4), (0, "female", 42, 15, "yes", 3, 18, 1, 4), (0, "male", 32, 15, "yes", 1, 16, 5, 5), (0, "male", 27, 4, "yes", 3, 18, 5, 5), (0, "female", 32, 15, "yes", 4, 12, 3, 4), (0, "male", 22, 0.75, "yes", 3, 14, 2, 4), (0, "female", 22, 1.5, "no", 3, 16, 5, 3), (0, "female", 42, 15, "yes", 4, 14, 3, 5), (0, "female", 52, 15, "yes", 3, 16, 5, 4), (0, "male", 37, 15, "yes", 5, 20, 6, 4), (0, "female", 47, 15, "yes", 4, 12, 2, 3), (0, "male", 57, 15, "yes", 2, 20, 6, 4), (0, "male", 32, 7, "yes", 4, 17, 5, 5), (0, "female", 27, 7, "yes", 4, 17, 1, 4), (0, "male", 22, 1.5, "no", 1, 18, 6, 5), (0, "female", 22, 4, "yes", 3, 9, 1, 4), (0, "female", 22, 1.5, "no", 2, 14, 1, 5), (0, "male", 42, 15, "yes", 2, 20, 6, 4), (0, "male", 57, 15, "yes", 4, 9, 2, 4), (0, "female", 27, 7, "yes", 2, 18, 1, 5), (0, "female", 22, 4, "yes", 3, 14, 1, 5), (0, "male", 37, 15, "yes", 4, 14, 5, 3), (0, "male", 32, 7, "yes", 1, 18, 6, 4), (0, "female", 22, 1.5, "no", 2, 14, 5, 5), (0, "female", 22, 1.5, "yes", 3, 12, 1, 3), (0, "male", 52, 15, "yes", 2, 14, 5, 5), (0, "female", 37, 15, "yes", 2, 14, 1, 1), (0, "female", 32, 10, "yes", 2, 14, 5, 5), (0, "male", 42, 15, "yes", 4, 20, 4, 5), (0, "female", 27, 4, "yes", 3, 18, 4, 5), (0, "male", 37, 15, "yes", 4, 20, 6, 5), (0, "male", 27, 1.5, "no", 3, 18, 5, 5), (0, "female", 22, 0.125, "no", 2, 16, 6, 3), (0, 
"male", 32, 10, "yes", 2, 20, 6, 3), (0, "female", 27, 4, "no", 4, 18, 5, 4), (0, "female", 27, 7, "yes", 2, 12, 5, 1), (0, "male", 32, 4, "yes", 5, 18, 6, 3), (0, "female", 37, 15, "yes", 2, 17, 5, 5), (0, "male", 47, 15, "no", 4, 20, 6, 4), (0, "male", 27, 1.5, "no", 1, 18, 5, 5), (0, "male", 37, 15, "yes", 4, 20, 6, 4), (0, "female", 32, 15, "yes", 4, 18, 1, 4), (0, "female", 32, 7, "yes", 4, 17, 5, 4), (0, "female", 42, 15, "yes", 3, 14, 1, 3), (0, "female", 27, 7, "yes", 3, 16, 1, 4), (0, "male", 27, 1.5, "no", 3, 16, 4, 2), (0, "male", 22, 1.5, "no", 3, 16, 3, 5), (0, "male", 27, 4, "yes", 3, 16, 4, 2), (0, "female", 27, 7, "yes", 3, 12, 1, 2), (0, "female", 37, 15, "yes", 2, 18, 5, 4), (0, "female", 37, 7, "yes", 3, 14, 4, 4), (0, "male", 22, 1.5, "no", 2, 16, 5, 5), (0, "male", 37, 15, "yes", 5, 20, 5, 4), (0, "female", 22, 1.5, "no", 4, 16, 5, 3), (0, "female", 32, 10, "yes", 4, 16, 1, 5), (0, "male", 27, 4, "no", 2, 17, 5, 3), (0, "female", 22, 0.417, "no", 4, 14, 5, 5), (0, "female", 27, 4, "no", 2, 18, 5, 5), (0, "male", 37, 15, "yes", 4, 18, 5, 3), (0, "male", 37, 10, "yes", 5, 20, 7, 4), (0, "female", 27, 7, "yes", 2, 14, 4, 2), (0, "male", 32, 4, "yes", 2, 16, 5, 5), (0, "male", 32, 4, "yes", 2, 16, 6, 4), (0, "male", 22, 1.5, "no", 3, 18, 4, 5), (0, "female", 22, 4, "yes", 4, 14, 3, 4), (0, "female", 17.5, 0.75, "no", 2, 18, 5, 4), (0, "male", 32, 10, "yes", 4, 20, 4, 5), (0, "female", 32, 0.75, "no", 5, 14, 3, 3), (0, "male", 37, 15, "yes", 4, 17, 5, 3), (0, "male", 32, 4, "no", 3, 14, 4, 5), (0, "female", 27, 1.5, "no", 2, 17, 3, 2), (0, "female", 22, 7, "yes", 4, 14, 1, 5), (0, "male", 47, 15, "yes", 5, 14, 6, 5), (0, "male", 27, 4, "yes", 1, 16, 4, 4), (0, "female", 37, 15, "yes", 5, 14, 1, 3), (0, "male", 42, 4, "yes", 4, 18, 5, 5), (0, "female", 32, 4, "yes", 2, 14, 1, 5), (0, "male", 52, 15, "yes", 2, 14, 7, 4), (0, "female", 22, 1.5, "no", 2, 16, 1, 4), (0, "male", 52, 15, "yes", 4, 12, 2, 4), (0, "female", 22, 0.417, "no", 3, 17, 1, 5), (0, 
"female", 22, 1.5, "no", 2, 16, 5, 5), (0, "male", 27, 4, "yes", 4, 20, 6, 4), (0, "female", 32, 15, "yes", 4, 14, 1, 5), (0, "female", 27, 1.5, "no", 2, 16, 3, 5), (0, "male", 32, 4, "no", 1, 20, 6, 5), (0, "male", 37, 15, "yes", 3, 20, 6, 4), (0, "female", 32, 10, "no", 2, 16, 6, 5), (0, "female", 32, 10, "yes", 5, 14, 5, 5), (0, "male", 37, 1.5, "yes", 4, 18, 5, 3), (0, "male", 32, 1.5, "no", 2, 18, 4, 4), (0, "female", 32, 10, "yes", 4, 14, 1, 4), (0, "female", 47, 15, "yes", 4, 18, 5, 4), (0, "female", 27, 10, "yes", 5, 12, 1, 5), (0, "male", 27, 4, "yes", 3, 16, 4, 5), (0, "female", 37, 15, "yes", 4, 12, 4, 2), (0, "female", 27, 0.75, "no", 4, 16, 5, 5), (0, "female", 37, 15, "yes", 4, 16, 1, 5), (0, "female", 32, 15, "yes", 3, 16, 1, 5), (0, "female", 27, 10, "yes", 2, 16, 1, 5), (0, "male", 27, 7, "no", 2, 20, 6, 5), (0, "female", 37, 15, "yes", 2, 14, 1, 3), (0, "male", 27, 1.5, "yes", 2, 17, 4, 4), (0, "female", 22, 0.75, "yes", 2, 14, 1, 5), (0, "male", 22, 4, "yes", 4, 14, 2, 4), (0, "male", 42, 0.125, "no", 4, 17, 6, 4), (0, "male", 27, 1.5, "yes", 4, 18, 6, 5), (0, "male", 27, 7, "yes", 3, 16, 6, 3), (0, "female", 52, 15, "yes", 4, 14, 1, 3), (0, "male", 27, 1.5, "no", 5, 20, 5, 2), (0, "female", 27, 1.5, "no", 2, 16, 5, 5), (0, "female", 27, 1.5, "no", 3, 17, 5, 5), (0, "male", 22, 0.125, "no", 5, 16, 4, 4), (0, "female", 27, 4, "yes", 4, 16, 1, 5), (0, "female", 27, 4, "yes", 4, 12, 1, 5), (0, "female", 47, 15, "yes", 2, 14, 5, 5), (0, "female", 32, 15, "yes", 3, 14, 5, 3), (0, "male", 42, 7, "yes", 2, 16, 5, 5), (0, "male", 22, 0.75, "no", 4, 16, 6, 4), (0, "male", 27, 0.125, "no", 3, 20, 6, 5), (0, "male", 32, 10, "yes", 3, 20, 6, 5), (0, "female", 22, 0.417, "no", 5, 14, 4, 5), (0, "female", 47, 15, "yes", 5, 14, 1, 4), (0, "female", 32, 10, "yes", 3, 14, 1, 5), (0, "male", 57, 15, "yes", 4, 17, 5, 5), (0, "male", 27, 4, "yes", 3, 20, 6, 5), (0, "female", 32, 7, "yes", 4, 17, 1, 5), (0, "female", 37, 10, "yes", 4, 16, 1, 5), (0, "female", 32, 10, 
"yes", 1, 18, 1, 4), (0, "female", 22, 4, "no", 3, 14, 1, 4), (0, "female", 27, 7, "yes", 4, 14, 3, 2), (0, "male", 57, 15, "yes", 5, 18, 5, 2), (0, "male", 32, 7, "yes", 2, 18, 5, 5), (0, "female", 27, 1.5, "no", 4, 17, 1, 3), (0, "male", 22, 1.5, "no", 4, 14, 5, 5), (0, "female", 22, 1.5, "yes", 4, 14, 5, 4), (0, "female", 32, 7, "yes", 3, 16, 1, 5), (0, "female", 47, 15, "yes", 3, 16, 5, 4), (0, "female", 22, 0.75, "no", 3, 16, 1, 5), (0, "female", 22, 1.5, "yes", 2, 14, 5, 5), (0, "female", 27, 4, "yes", 1, 16, 5, 5), (0, "male", 52, 15, "yes", 4, 16, 5, 5), (0, "male", 32, 10, "yes", 4, 20, 6, 5), (0, "male", 47, 15, "yes", 4, 16, 6, 4), (0, "female", 27, 7, "yes", 2, 14, 1, 2), (0, "female", 22, 1.5, "no", 4, 14, 4, 5), (0, "female", 32, 10, "yes", 2, 16, 5, 4), (0, "female", 22, 0.75, "no", 2, 16, 5, 4), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 42, 15, "yes", 3, 18, 6, 4), (0, "female", 27, 7, "yes", 5, 14, 4, 5), (0, "male", 42, 15, "yes", 4, 16, 4, 4), (0, "female", 57, 15, "yes", 3, 18, 5, 2), (0, "male", 42, 15, "yes", 3, 18, 6, 2), (0, "female", 32, 7, "yes", 2, 14, 1, 2), (0, "male", 22, 4, "no", 5, 12, 4, 5), (0, "female", 22, 1.5, "no", 1, 16, 6, 5), (0, "female", 22, 0.75, "no", 1, 14, 4, 5), (0, "female", 32, 15, "yes", 4, 12, 1, 5), (0, "male", 22, 1.5, "no", 2, 18, 5, 3), (0, "male", 27, 4, "yes", 5, 17, 2, 5), (0, "female", 27, 4, "yes", 4, 12, 1, 5), (0, "male", 42, 15, "yes", 5, 18, 5, 4), (0, "male", 32, 1.5, "no", 2, 20, 7, 3), (0, "male", 57, 15, "no", 4, 9, 3, 1), (0, "male", 37, 7, "no", 4, 18, 5, 5), (0, "male", 52, 15, "yes", 2, 17, 5, 4), (0, "male", 47, 15, "yes", 4, 17, 6, 5), (0, "female", 27, 7, "no", 2, 17, 5, 4), (0, "female", 27, 7, "yes", 4, 14, 5, 5), (0, "female", 22, 4, "no", 2, 14, 3, 3), (0, "male", 37, 7, "yes", 2, 20, 6, 5), (0, "male", 27, 7, "no", 4, 12, 4, 3), (0, "male", 42, 10, "yes", 4, 18, 6, 4), (0, "female", 22, 1.5, "no", 3, 14, 1, 5), (0, "female", 22, 4, "yes", 2, 14, 1, 3), (0, "female", 57, 
15, "no", 4, 20, 6, 5), (0, "male", 37, 15, "yes", 4, 14, 4, 3), (0, "female", 27, 7, "yes", 3, 18, 5, 5), (0, "female", 17.5, 10, "no", 4, 14, 4, 5), (0, "male", 22, 4, "yes", 4, 16, 5, 5), (0, "female", 27, 4, "yes", 2, 16, 1, 4), (0, "female", 37, 15, "yes", 2, 14, 5, 1), (0, "female", 22, 1.5, "no", 5, 14, 1, 4), (0, "male", 27, 7, "yes", 2, 20, 5, 4), (0, "male", 27, 4, "yes", 4, 14, 5, 5), (0, "male", 22, 0.125, "no", 1, 16, 3, 5), (0, "female", 27, 7, "yes", 4, 14, 1, 4), (0, "female", 32, 15, "yes", 5, 16, 5, 3), (0, "male", 32, 10, "yes", 4, 18, 5, 4), (0, "female", 32, 15, "yes", 2, 14, 3, 4), (0, "female", 22, 1.5, "no", 3, 17, 5, 5), (0, "male", 27, 4, "yes", 4, 17, 4, 4), (0, "female", 52, 15, "yes", 5, 14, 1, 5), (0, "female", 27, 7, "yes", 2, 12, 1, 2), (0, "female", 27, 7, "yes", 3, 12, 1, 4), (0, "female", 42, 15, "yes", 2, 14, 1, 4), (0, "female", 42, 15, "yes", 4, 14, 5, 4), (0, "male", 27, 7, "yes", 4, 14, 3, 3), (0, "male", 27, 7, "yes", 2, 20, 6, 2), (0, "female", 42, 15, "yes", 3, 12, 3, 3), (0, "male", 27, 4, "yes", 3, 16, 3, 5), (0, "female", 27, 7, "yes", 3, 14, 1, 4), (0, "female", 22, 1.5, "no", 2, 14, 4, 5), (0, "female", 27, 4, "yes", 4, 14, 1, 4), (0, "female", 22, 4, "no", 4, 14, 5, 5), (0, "female", 22, 1.5, "no", 2, 16, 4, 5), (0, "male", 47, 15, "no", 4, 14, 5, 4), (0, "male", 37, 10, "yes", 2, 18, 6, 2), (0, "male", 37, 15, "yes", 3, 17, 5, 4), (0, "female", 27, 4, "yes", 2, 16, 1, 4), (3, "male", 27, 1.5, "no", 3, 18, 4, 4), (3, "female", 27, 4, "yes", 3, 17, 1, 5), (7, "male", 37, 15, "yes", 5, 18, 6, 2), (12, "female", 32, 10, "yes", 3, 17, 5, 2), (1, "male", 22, 0.125, "no", 4, 16, 5, 5), (1, "female", 22, 1.5, "yes", 2, 14, 1, 5), (12, "male", 37, 15, "yes", 4, 14, 5, 2), (7, "female", 22, 1.5, "no", 2, 14, 3, 4), (2, "male", 37, 15, "yes", 2, 18, 6, 4), (3, "female", 32, 15, "yes", 4, 12, 3, 2), (1, "female", 37, 15, "yes", 4, 14, 4, 2), (7, "female", 42, 15, "yes", 3, 17, 1, 4), (12, "female", 42, 15, "yes", 5, 9, 4, 1), 
(12, "male", 37, 10, "yes", 2, 20, 6, 2), (12, "female", 32, 15, "yes", 3, 14, 1, 2), (3, "male", 27, 4, "no", 1, 18, 6, 5), (7, "male", 37, 10, "yes", 2, 18, 7, 3), (7, "female", 27, 4, "no", 3, 17, 5, 5), (1, "male", 42, 15, "yes", 4, 16, 5, 5), (1, "female", 47, 15, "yes", 5, 14, 4, 5), (7, "female", 27, 4, "yes", 3, 18, 5, 4), (1, "female", 27, 7, "yes", 5, 14, 1, 4), (12, "male", 27, 1.5, "yes", 3, 17, 5, 4), (12, "female", 27, 7, "yes", 4, 14, 6, 2), (3, "female", 42, 15, "yes", 4, 16, 5, 4), (7, "female", 27, 10, "yes", 4, 12, 7, 3), (1, "male", 27, 1.5, "no", 2, 18, 5, 2), (1, "male", 32, 4, "no", 4, 20, 6, 4), (1, "female", 27, 7, "yes", 3, 14, 1, 3), (3, "female", 32, 10, "yes", 4, 14, 1, 4), (3, "male", 27, 4, "yes", 2, 18, 7, 2), (1, "female", 17.5, 0.75, "no", 5, 14, 4, 5), (1, "female", 32, 10, "yes", 4, 18, 1, 5), (7, "female", 32, 7, "yes", 2, 17, 6, 4), (7, "male", 37, 15, "yes", 2, 20, 6, 4), (7, "female", 37, 10, "no", 1, 20, 5, 3), (12, "female", 32, 10, "yes", 2, 16, 5, 5), (7, "male", 52, 15, "yes", 2, 20, 6, 4), (7, "female", 42, 15, "yes", 1, 12, 1, 3), (1, "male", 52, 15, "yes", 2, 20, 6, 3), (2, "male", 37, 15, "yes", 3, 18, 6, 5), (12, "female", 22, 4, "no", 3, 12, 3, 4), (12, "male", 27, 7, "yes", 1, 18, 6, 2), (1, "male", 27, 4, "yes", 3, 18, 5, 5), (12, "male", 47, 15, "yes", 4, 17, 6, 5), (12, "female", 42, 15, "yes", 4, 12, 1, 1), (7, "male", 27, 4, "no", 3, 14, 3, 4), (7, "female", 32, 7, "yes", 4, 18, 4, 5), (1, "male", 32, 0.417, "yes", 3, 12, 3, 4), (3, "male", 47, 15, "yes", 5, 16, 5, 4), (12, "male", 37, 15, "yes", 2, 20, 5, 4), (7, "male", 22, 4, "yes", 2, 17, 6, 4), (1, "male", 27, 4, "no", 2, 14, 4, 5), (7, "female", 52, 15, "yes", 5, 16, 1, 3), (1, "male", 27, 4, "no", 3, 14, 3, 3), (1, "female", 27, 10, "yes", 4, 16, 1, 4), (1, "male", 32, 7, "yes", 3, 14, 7, 4), (7, "male", 32, 7, "yes", 2, 18, 4, 1), (3, "male", 22, 1.5, "no", 1, 14, 3, 2), (7, "male", 22, 4, "yes", 3, 18, 6, 4), (7, "male", 42, 15, "yes", 4, 20, 6, 4), 
(2, "female", 57, 15, "yes", 1, 18, 5, 4), (7, "female", 32, 4, "yes", 3, 18, 5, 2), (1, "male", 27, 4, "yes", 1, 16, 4, 4), (7, "male", 32, 7, "yes", 4, 16, 1, 4), (2, "male", 57, 15, "yes", 1, 17, 4, 4), (7, "female", 42, 15, "yes", 4, 14, 5, 2), (7, "male", 37, 10, "yes", 1, 18, 5, 3), (3, "male", 42, 15, "yes", 3, 17, 6, 1), (1, "female", 52, 15, "yes", 3, 14, 4, 4), (2, "female", 27, 7, "yes", 3, 17, 5, 3), (12, "male", 32, 7, "yes", 2, 12, 4, 2), (1, "male", 22, 4, "no", 4, 14, 2, 5), (3, "male", 27, 7, "yes", 3, 18, 6, 4), (12, "female", 37, 15, "yes", 1, 18, 5, 5), (7, "female", 32, 15, "yes", 3, 17, 1, 3), (7, "female", 27, 7, "no", 2, 17, 5, 5), (1, "female", 32, 7, "yes", 3, 17, 5, 3), (1, "male", 32, 1.5, "yes", 2, 14, 2, 4), (12, "female", 42, 15, "yes", 4, 14, 1, 2), (7, "male", 32, 10, "yes", 3, 14, 5, 4), (7, "male", 37, 4, "yes", 1, 20, 6, 3), (1, "female", 27, 4, "yes", 2, 16, 5, 3), (12, "female", 42, 15, "yes", 3, 14, 4, 3), (1, "male", 27, 10, "yes", 5, 20, 6, 5), (12, "male", 37, 10, "yes", 2, 20, 6, 2), (12, "female", 27, 7, "yes", 1, 14, 3, 3), (3, "female", 27, 7, "yes", 4, 12, 1, 2), (3, "male", 32, 10, "yes", 2, 14, 4, 4), (12, "female", 17.5, 0.75, "yes", 2, 12, 1, 3), (12, "female", 32, 15, "yes", 3, 18, 5, 4), (2, "female", 22, 7, "no", 4, 14, 4, 3), (1, "male", 32, 7, "yes", 4, 20, 6, 5), (7, "male", 27, 4, "yes", 2, 18, 6, 2), (1, "female", 22, 1.5, "yes", 5, 14, 5, 3), (12, "female", 32, 15, "no", 3, 17, 5, 1), (12, "female", 42, 15, "yes", 2, 12, 1, 2), (7, "male", 42, 15, "yes", 3, 20, 5, 4), (12, "male", 32, 10, "no", 2, 18, 4, 2), (12, "female", 32, 15, "yes", 3, 9, 1, 1), (7, "male", 57, 15, "yes", 5, 20, 4, 5), (12, "male", 47, 15, "yes", 4, 20, 6, 4), (2, "female", 42, 15, "yes", 2, 17, 6, 3), (12, "male", 37, 15, "yes", 3, 17, 6, 3), (12, "male", 37, 15, "yes", 5, 17, 5, 2), (7, "male", 27, 10, "yes", 2, 20, 6, 4), (2, "male", 37, 15, "yes", 2, 16, 5, 4), (12, "female", 32, 15, "yes", 1, 14, 5, 2), (7, "male", 32, 10, "yes", 
3, 17, 6, 3), (2, "male", 37, 15, "yes", 4, 18, 5, 1), (7, "female", 27, 1.5, "no", 2, 17, 5, 5), (3, "female", 47, 15, "yes", 2, 17, 5, 2), (12, "male", 37, 15, "yes", 2, 17, 5, 4), (12, "female", 27, 4, "no", 2, 14, 5, 5), (2, "female", 27, 10, "yes", 4, 14, 1, 5), (1, "female", 22, 4, "yes", 3, 16, 1, 3), (12, "male", 52, 7, "no", 4, 16, 5, 5), (2, "female", 27, 4, "yes", 1, 16, 3, 5), (7, "female", 37, 15, "yes", 2, 17, 6, 4), (2, "female", 27, 4, "no", 1, 17, 3, 1), (12, "female", 17.5, 0.75, "yes", 2, 12, 3, 5), (7, "female", 32, 15, "yes", 5, 18, 5, 4), (7, "female", 22, 4, "no", 1, 16, 3, 5), (2, "male", 32, 4, "yes", 4, 18, 6, 4), (1, "female", 22, 1.5, "yes", 3, 18, 5, 2), (3, "female", 42, 15, "yes", 2, 17, 5, 4), (1, "male", 32, 7, "yes", 4, 16, 4, 4), (12, "male", 37, 15, "no", 3, 14, 6, 2), (1, "male", 42, 15, "yes", 3, 16, 6, 3), (1, "male", 27, 4, "yes", 1, 18, 5, 4), (2, "male", 37, 15, "yes", 4, 20, 7, 3), (7, "male", 37, 15, "yes", 3, 20, 6, 4), (3, "male", 22, 1.5, "no", 2, 12, 3, 3), (3, "male", 32, 4, "yes", 3, 20, 6, 2), (2, "male", 32, 15, "yes", 5, 20, 6, 5), (12, "female", 52, 15, "yes", 1, 18, 5, 5), (12, "male", 47, 15, "no", 1, 18, 6, 5), (3, "female", 32, 15, "yes", 4, 16, 4, 4), (7, "female", 32, 15, "yes", 3, 14, 3, 2), (7, "female", 27, 7, "yes", 4, 16, 1, 2), (12, "male", 42, 15, "yes", 3, 18, 6, 2), (7, "female", 42, 15, "yes", 2, 14, 3, 2), (12, "male", 27, 7, "yes", 2, 17, 5, 4), (3, "male", 32, 10, "yes", 4, 14, 4, 3), (7, "male", 47, 15, "yes", 3, 16, 4, 2), (1, "male", 22, 1.5, "yes", 1, 12, 2, 5), (7, "female", 32, 10, "yes", 2, 18, 5, 4), (2, "male", 32, 10, "yes", 2, 17, 6, 5), (2, "male", 22, 7, "yes", 3, 18, 6, 2), (1, "female", 32, 15, "yes", 3, 14, 1, 5)) val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
随机森林建模
// ---------------------------------------------------------------------------
// Random forest classification on the `data` DataFrame.
//
// Steps: derive a binary label and numeric features with Spark SQL, assemble
// them into a feature vector, split into training/test sets, train a
// 10-tree random forest inside a Pipeline, then report the test error and
// print the learned forest.
// ---------------------------------------------------------------------------

// Register the raw DataFrame so it can be queried with Spark SQL.
data.createOrReplaceTempView("data")

// Convert the label and the string columns to numeric values:
//   label    = 0 when affairs = 0, otherwise 1
//   gender   = 0 for 'female', otherwise 1
//   children = 0 for 'no', otherwise 1
val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"
val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"
val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"

val dataLabelDF = spark.sql(
  s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data")

// Columns that make up the feature vector, in vector order
// (feature 0 = gender, feature 1 = age, ...).
val featuresArray = Array(
  "gender", "age", "yearsmarried", "children",
  "religiousness", "education", "occupation", "rating")

// Assemble the individual columns into a single "features" vector column.
val assembler = new VectorAssembler()
  .setInputCols(featuresArray)
  .setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Split into training and test sets (roughly 70% / 30%).
// A fixed seed keeps the split, and therefore the reported test error,
// repeatable between runs on the same input.
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3), seed = 1234L)

// Index the label column, adding metadata to it.
// Fitted on the full vecDF (not only the training split), following the
// Spark ML guide, so that all label values are included in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(vecDF)
// labelIndexer.transform(vecDF).show(10, truncate = false)

// Automatically identify categorical features and index them.
// Features with more than 5 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(5)
  .fit(vecDF)
// featureIndexer.transform(vecDF).show(10, truncate = false)

// Random forest estimator with 10 trees, reading the indexed columns.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10)

// Convert indexed predictions back to the original label values.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingDF)

// Locate the fitted forest among the pipeline stages by type rather than by
// a hard-coded stage index and an unchecked cast, so that reordering the
// pipeline stages cannot silently pick the wrong stage.
val rfModel: RandomForestClassificationModel = model.stages
  .collectFirst { case forest: RandomForestClassificationModel => forest }
  .getOrElse(
    throw new IllegalStateException(
      "Fitted pipeline contains no RandomForestClassificationModel stage"))

// Print every parameter value of the fitted random forest model.
// (Printed explicitly so the values are visible outside the REPL as well.)
println(rfModel.extractParamMap())

// Make predictions on the held-out test set.
val predictions = model.transform(testDF)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(10, truncate = false)

// Compare (prediction, true label) and compute the test error = 1 - accuracy.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

// Dump the structure of every tree in the learned forest.
println(s"Learned classification forest model:\n${rfModel.toDebugString}")
代码执行结果
vecDF.show(10, truncate = false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]| |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]| |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]| |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]| |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]| |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]| |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]| |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ only showing top 10 rows // 将数据分为训练和测试集(30%进行测试) val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3)) trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields] testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 
8 more fields] // 索引标签,将元数据添加到标签列中 val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF) labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_37df210602df //labelIndexer.transform(vecDF).show(10, truncate = false) // 自动识别分类的特征,并对它们进行索引 // 具有大于5个不同的值的特征被视为连续。 val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF) featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_9595c228f520 //featureIndexer.transform(vecDF).show(10, truncate = false) // 训练随机森林模型 val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10) rf: org.apache.spark.ml.classification.RandomForestClassifier = rfc_d0e7623d0b10 // 将索引标签转换回原始标签 val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels) labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_32d6938f2c94 // Chain indexers and forest in a Pipeline. val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) pipeline: org.apache.spark.ml.Pipeline = pipeline_97716da42fed // Train model. This also runs the indexers. 
val model = pipeline.fit(trainingDF) model: org.apache.spark.ml.PipelineModel = pipeline_97716da42fed // 输出随机森林模型的全部参数值 model.stages(2).extractParamMap() res10: org.apache.spark.ml.param.ParamMap = { rfc_0d830180d598-cacheNodeIds: false, rfc_0d830180d598-checkpointInterval: 10, rfc_0d830180d598-featureSubsetStrategy: auto, rfc_0d830180d598-featuresCol: indexedFeatures, rfc_0d830180d598-impurity: gini, rfc_0d830180d598-labelCol: indexedLabel, rfc_0d830180d598-maxBins: 32, rfc_0d830180d598-maxDepth: 5, rfc_0d830180d598-maxMemoryInMB: 256, rfc_0d830180d598-minInfoGain: 0.0, rfc_0d830180d598-minInstancesPerNode: 1, rfc_0d830180d598-predictionCol: prediction, rfc_0d830180d598-probabilityCol: probability, rfc_0d830180d598-rawPredictionCol: rawPrediction, rfc_0d830180d598-seed: 207336481, rfc_0d830180d598-subsamplingRate: 1.0 } // 作出预测 val predictions = model.transform(testDF) predictions: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 14 more fields] predictions.select("predictedLabel", "label", "features").show(10,false) +--------------+-----+-------------------------------------+ |predictedLabel|label|features | +--------------+-----+-------------------------------------+ |0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,12.0,4.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]| |0.0 |0.0 |[0.0,22.0,0.417,0.0,4.0,14.0,5.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.75,0.0,5.0,18.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,16.0,5.0,3.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,17.0,5.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,1.0,3.0,12.0,1.0,3.0] | +--------------+-----+-------------------------------------+ only showing top 10 rows // 选择(预测标签,实际标签),并计算测试误差 val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy") evaluator: 
org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_13a195abc422 val accuracy = evaluator.evaluate(predictions) accuracy: Double = 0.7365591397849462 println("Test Error = " + (1.0 - accuracy)) Test Error = 0.26344086021505375 // 这里的stages(2)中的“2”对应pipeline中的“rf”,将model强制转换为RandomForestClassificationModel类型 val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] rfModel: org.apache.spark.ml.classification.RandomForestClassificationModel = RandomForestClassificationModel (uid=rfc_f7bb5e488533) with 10 trees println("Learned classification forest model:\n" + rfModel.toDebugString) Learned classification forest model: RandomForestClassificationModel (uid=rfc_f7bb5e488533) with 10 trees Tree 0 (weight 1.0): If (feature 2 <= 1.5) If (feature 5 <= 12.0) If (feature 6 <= 1.0) Predict: 0.0 Else (feature 6 > 1.0) If (feature 2 <= 0.125) Predict: 0.0 Else (feature 2 > 0.125) Predict: 1.0 Else (feature 5 > 12.0) If (feature 0 in {0.0}) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 2 <= 0.75) If (feature 4 in {0.0,1.0,2.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,2.0,4.0}) Predict: 0.0 Else (feature 2 > 0.75) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 1.0 Else (feature 2 > 1.5) If (feature 1 <= 42.0) If (feature 1 <= 27.0) If (feature 5 <= 16.0) If (feature 6 <= 5.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 1.0 Else (feature 5 > 16.0) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) Predict: 0.0 Else (feature 1 > 27.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 2 <= 4.0) Predict: 1.0 Else (feature 2 > 4.0) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 1 > 42.0) If (feature 4 in {2.0,4.0}) Predict: 0.0 Else (feature 4 not in {2.0,4.0}) If (feature 4 in {0.0}) 
Predict: 1.0 Else (feature 4 not in {0.0}) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Tree 1 (weight 1.0): If (feature 7 in {0.0,2.0,4.0}) If (feature 7 in {0.0}) If (feature 1 <= 42.0) If (feature 4 in {1.0}) Predict: 0.0 Else (feature 4 not in {1.0}) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 1 <= 17.5) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) Predict: 1.0 Else (feature 1 > 17.5) If (feature 0 in {0.0}) If (feature 4 in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 6 <= 2.0) Predict: 1.0 Else (feature 6 > 2.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) If (feature 3 in {0.0}) If (feature 5 <= 14.0) If (feature 4 in {1.0,3.0}) Predict: 0.0 Else (feature 4 not in {1.0,3.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 1.0 Else (feature 5 > 14.0) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 4 in {0.0,2.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,2.0,3.0,4.0}) Predict: 1.0 Else (feature 3 not in {0.0}) If (feature 5 <= 12.0) If (feature 0 in {1.0}) Predict: 0.0 Else (feature 0 not in {1.0}) If (feature 6 <= 1.0) Predict: 0.0 Else (feature 6 > 1.0) Predict: 0.0 Else (feature 5 > 12.0) If (feature 4 in {0.0,2.0,3.0,4.0}) If (feature 1 <= 47.0) Predict: 0.0 Else (feature 1 > 47.0) Predict: 1.0 Else (feature 4 not in {0.0,2.0,3.0,4.0}) If (feature 1 <= 22.0) Predict: 1.0 Else (feature 1 > 22.0) Predict: 0.0 Tree 2 (weight 1.0): If (feature 7 in {0.0}) If (feature 4 in {1.0}) Predict: 0.0 Else (feature 4 not in {1.0}) If (feature 6 <= 5.0) If (feature 1 <= 42.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 5 <= 16.0) If (feature 7 in {1.0}) If (feature 6 <= 4.0) If (feature 2 <= 7.0) Predict: 0.0 Else (feature 2 > 7.0) 
Predict: 1.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 7 not in {1.0}) If (feature 3 in {1.0}) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 3 not in {1.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 5 > 16.0) If (feature 3 in {0.0}) If (feature 4 in {4.0}) Predict: 0.0 Else (feature 4 not in {4.0}) If (feature 5 <= 18.0) Predict: 0.0 Else (feature 5 > 18.0) Predict: 0.0 Else (feature 3 not in {0.0}) If (feature 4 in {0.0,3.0,4.0}) If (feature 7 in {2.0}) Predict: 0.0 Else (feature 7 not in {2.0}) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Tree 3 (weight 1.0): If (feature 3 in {0.0}) If (feature 7 in {3.0}) Predict: 0.0 Else (feature 7 not in {3.0}) If (feature 2 <= 10.0) If (feature 4 in {2.0,3.0,4.0}) If (feature 4 in {4.0}) Predict: 0.0 Else (feature 4 not in {4.0}) Predict: 0.0 Else (feature 4 not in {2.0,3.0,4.0}) If (feature 7 in {0.0,2.0,4.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) Predict: 1.0 Else (feature 2 > 10.0) Predict: 1.0 Else (feature 3 not in {0.0}) If (feature 6 <= 2.0) If (feature 5 <= 16.0) If (feature 7 in {0.0,1.0,2.0,4.0}) If (feature 4 in {0.0,1.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,3.0,4.0}) Predict: 1.0 Else (feature 7 not in {0.0,1.0,2.0,4.0}) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 5 > 16.0) If (feature 7 in {0.0,1.0,3.0}) Predict: 0.0 Else (feature 7 not in {0.0,1.0,3.0}) Predict: 1.0 Else (feature 6 > 2.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 7 in {0.0,2.0,3.0,4.0}) If (feature 4 in {3.0,4.0}) Predict: 0.0 Else (feature 4 not in {3.0,4.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0,3.0,4.0}) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 1 <= 22.0) If (feature 5 <= 14.0) Predict: 1.0 Else (feature 5 > 
14.0) Predict: 1.0 Else (feature 1 > 22.0) If (feature 6 <= 6.0) Predict: 0.0 Else (feature 6 > 6.0) Predict: 1.0 Tree 4 (weight 1.0): If (feature 7 in {0.0,2.0,4.0}) If (feature 7 in {0.0}) If (feature 6 <= 5.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) If (feature 4 in {2.0,4.0}) Predict: 1.0 Else (feature 4 not in {2.0,4.0}) Predict: 1.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 2 <= 1.5) If (feature 5 <= 12.0) If (feature 2 <= 0.125) Predict: 0.0 Else (feature 2 > 0.125) Predict: 0.0 Else (feature 5 > 12.0) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 2 > 1.5) If (feature 2 <= 7.0) If (feature 4 in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 2 > 7.0) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) If (feature 5 <= 12.0) Predict: 0.0 Else (feature 5 > 12.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 1 <= 47.0) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 1 > 47.0) Predict: 1.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 1 <= 27.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 1 > 27.0) If (feature 5 <= 14.0) Predict: 1.0 Else (feature 5 > 14.0) Predict: 1.0 Tree 5 (weight 1.0): If (feature 7 in {0.0}) If (feature 1 <= 42.0) If (feature 6 <= 4.0) Predict: 1.0 Else (feature 6 > 4.0) If (feature 4 in {1.0}) Predict: 0.0 Else (feature 4 not in {1.0}) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 2 <= 1.5) If (feature 4 in {0.0,2.0,3.0}) If (feature 1 <= 22.0) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 4 not in {0.0,2.0,3.0}) If (feature 1 <= 17.5) If (feature 6 <= 4.0) Predict: 1.0 Else (feature 6 > 4.0) Predict: 0.0 Else 
(feature 1 > 17.5) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 2 > 1.5) If (feature 6 <= 5.0) If (feature 5 <= 17.0) If (feature 7 in {2.0,4.0}) Predict: 0.0 Else (feature 7 not in {2.0,4.0}) Predict: 0.0 Else (feature 5 > 17.0) If (feature 6 <= 1.0) Predict: 0.0 Else (feature 6 > 1.0) Predict: 0.0 Else (feature 6 > 5.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 7 in {3.0,4.0}) Predict: 0.0 Else (feature 7 not in {3.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 6 <= 6.0) Predict: 0.0 Else (feature 6 > 6.0) Predict: 0.0 Tree 6 (weight 1.0): If (feature 4 in {0.0,3.0,4.0}) If (feature 5 <= 12.0) If (feature 7 in {1.0,2.0,3.0,4.0}) Predict: 0.0 Else (feature 7 not in {1.0,2.0,3.0,4.0}) If (feature 6 <= 3.0) Predict: 0.0 Else (feature 6 > 3.0) Predict: 1.0 Else (feature 5 > 12.0) If (feature 7 in {0.0,1.0,2.0}) If (feature 6 <= 1.0) If (feature 7 in {0.0,2.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0}) Predict: 0.0 Else (feature 6 > 1.0) If (feature 1 <= 37.0) Predict: 1.0 Else (feature 1 > 37.0) Predict: 0.0 Else (feature 7 not in {0.0,1.0,2.0}) If (feature 1 <= 17.5) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) Predict: 1.0 Else (feature 1 > 17.5) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 7 in {0.0,4.0}) If (feature 5 <= 12.0) If (feature 2 <= 0.125) Predict: 0.0 Else (feature 2 > 0.125) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 5 > 12.0) If (feature 7 in {0.0}) If (feature 1 <= 42.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 2 <= 1.5) Predict: 0.0 Else (feature 2 > 1.5) Predict: 0.0 Else (feature 7 not in {0.0,4.0}) If (feature 6 <= 4.0) If (feature 7 in {3.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 7 not in {3.0}) If (feature 5 <= 16.0) 
Predict: 0.0 Else (feature 5 > 16.0) Predict: 1.0 Else (feature 6 > 4.0) If (feature 6 <= 6.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 1.0 Else (feature 6 > 6.0) If (feature 5 <= 18.0) Predict: 1.0 Else (feature 5 > 18.0) Predict: 0.0 Tree 7 (weight 1.0): If (feature 7 in {0.0,2.0,4.0}) If (feature 2 <= 1.5) If (feature 4 in {1.0,2.0,3.0}) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 4 not in {1.0,2.0,3.0}) If (feature 5 <= 14.0) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 1.0 Else (feature 5 > 14.0) Predict: 0.0 Else (feature 2 > 1.5) If (feature 7 in {0.0,2.0}) If (feature 4 in {1.0,3.0,4.0}) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) Predict: 0.0 Else (feature 4 not in {1.0,3.0,4.0}) If (feature 6 <= 5.0) Predict: 1.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0}) If (feature 4 in {0.0,1.0,3.0}) If (feature 1 <= 42.0) Predict: 0.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 4 not in {0.0,1.0,3.0}) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) If (feature 2 <= 0.75) Predict: 0.0 Else (feature 2 > 0.75) If (feature 4 in {4.0}) If (feature 6 <= 5.0) If (feature 1 <= 37.0) Predict: 1.0 Else (feature 1 > 37.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 4 not in {4.0}) If (feature 5 <= 12.0) If (feature 1 <= 27.0) Predict: 0.0 Else (feature 1 > 27.0) Predict: 0.0 Else (feature 5 > 12.0) If (feature 7 in {1.0}) Predict: 1.0 Else (feature 7 not in {1.0}) Predict: 0.0 Tree 8 (weight 1.0): If (feature 5 <= 16.0) If (feature 4 in {0.0,1.0}) If (feature 0 in {0.0}) If (feature 2 <= 0.75) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 2 > 0.75) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 5 <= 12.0) Predict: 1.0 Else (feature 5 > 
12.0) If (feature 7 in {2.0,4.0}) Predict: 0.0 Else (feature 7 not in {2.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0}) If (feature 7 in {0.0,2.0,3.0,4.0}) If (feature 1 <= 22.0) If (feature 6 <= 3.0) Predict: 0.0 Else (feature 6 > 3.0) Predict: 0.0 Else (feature 1 > 22.0) If (feature 6 <= 6.0) Predict: 0.0 Else (feature 6 > 6.0) Predict: 1.0 Else (feature 7 not in {0.0,2.0,3.0,4.0}) If (feature 1 <= 42.0) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 5 > 16.0) If (feature 5 <= 18.0) If (feature 4 in {3.0}) If (feature 7 in {1.0,2.0,3.0}) Predict: 0.0 Else (feature 7 not in {1.0,2.0,3.0}) If (feature 6 <= 5.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 4 not in {3.0}) If (feature 2 <= 0.75) Predict: 0.0 Else (feature 2 > 0.75) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 1.0 Else (feature 5 > 18.0) If (feature 1 <= 27.0) If (feature 7 in {3.0}) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 1.0 Else (feature 7 not in {3.0}) If (feature 2 <= 4.0) Predict: 0.0 Else (feature 2 > 4.0) Predict: 1.0 Else (feature 1 > 27.0) If (feature 6 <= 5.0) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 0.0 Else (feature 6 > 5.0) If (feature 4 in {3.0,4.0}) Predict: 0.0 Else (feature 4 not in {3.0,4.0}) Predict: 0.0 Tree 9 (weight 1.0): If (feature 5 <= 16.0) If (feature 6 <= 2.0) If (feature 1 <= 42.0) If (feature 6 <= 1.0) If (feature 5 <= 9.0) Predict: 1.0 Else (feature 5 > 9.0) Predict: 0.0 Else (feature 6 > 1.0) If (feature 1 <= 27.0) Predict: 0.0 Else (feature 1 > 27.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 6 > 2.0) If (feature 1 <= 27.0) If (feature 5 <= 14.0) If (feature 6 <= 3.0) Predict: 0.0 Else (feature 6 > 3.0) Predict: 0.0 Else (feature 5 > 14.0) Predict: 0.0 Else (feature 1 > 27.0) If (feature 4 in {1.0,2.0,4.0}) If (feature 5 <= 9.0) Predict: 0.0 Else (feature 5 
> 9.0) Predict: 0.0 Else (feature 4 not in {1.0,2.0,4.0}) If (feature 7 in {2.0,3.0,4.0}) Predict: 0.0 Else (feature 7 not in {2.0,3.0,4.0}) Predict: 1.0 Else (feature 5 > 16.0) If (feature 6 <= 4.0) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) If (feature 1 <= 42.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 1 > 42.0) Predict: 1.0 Else (feature 6 > 4.0) If (feature 4 in {3.0,4.0}) If (feature 1 <= 37.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 1 > 37.0) If (feature 1 <= 42.0) Predict: 0.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 4 not in {3.0,4.0}) If (feature 4 in {0.0,2.0}) If (feature 7 in {0.0,1.0,2.0}) Predict: 1.0 Else (feature 7 not in {0.0,1.0,2.0}) Predict: 1.0 Else (feature 4 not in {0.0,2.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0
随机森林模型调优
// ---------------------------------------------------------------------------
// Random forest tuning: build the same indexing + forest pipeline as above and
// pick hyper-parameters with 5-fold cross-validation on the training split.
// Relies on `featuresArray` and `dataLabelDF` defined earlier in this file.
// ---------------------------------------------------------------------------

// Assemble the individual feature columns into a single vector column.
val assembler = new VectorAssembler()
  .setInputCols(featuresArray)
  .setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Split into training and test sets (30% held out for testing).
// A fixed seed keeps the split reproducible between runs.
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3), 1234L)

// Index the label column, adding metadata to it.
// Fitted on the full data set so that every label value is present in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(vecDF)
//labelIndexer.transform(vecDF).show(10, truncate = false)

// Automatically identify categorical features and index them.
// Features with more than 5 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(5)
  .fit(vecDF)
//featureIndexer.transform(vecDF).show(10, truncate = false)

// Random forest estimator. The seed is pinned here instead of being put in the
// parameter grid: searching over seeds only selects on sampling noise.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setSeed(123456L)

// Convert indexed predictions back to the original label values.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Parameter grid.
//   impurity              impurity measure used to choose splits
//   maxBins               maximum number of bins used to discretize continuous features
//   maxDepth              maximum depth of each tree
//   minInfoGain           minimum information gain required to split a node (>= 0.0)
//   minInstancesPerNode   minimum number of instances each child must have (>= 1)
//   numTrees              number of trees in the forest
//   featureSubsetStrategy number of features considered per split; see the Spark docs for all options
//   subsamplingRate       fraction of the training data used to learn each tree, in (0, 1]
//
// maxMemoryInMB, checkpointInterval and cacheNodeIds are deliberately left at
// their defaults (256, 10, false): they control memory use, checkpointing and
// node-id caching during training, so crossing them with the parameters above
// only multiplied the search (9216 combinations x 5 folds) without adding
// meaningful candidates. The grid below has 576 combinations.
//
// NOTE(review): minInfoGain values of 0.5 and 1.0 are very high thresholds and
// will probably prevent most splits -- consider smaller values such as 0.01.
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.impurity, Array("entropy", "gini"))
  .addGrid(rf.maxBins, Array(32, 64))
  .addGrid(rf.maxDepth, Array(5, 7, 10))
  .addGrid(rf.minInfoGain, Array(0.0, 0.5, 1.0))
  .addGrid(rf.minInstancesPerNode, Array(10, 20))
  .addGrid(rf.numTrees, Array(20, 50))
  .addGrid(rf.featureSubsetStrategy, Array("auto", "sqrt"))
  .addGrid(rf.subsamplingRate, Array(0.8, 1.0))
  .build()

// Accuracy evaluator. indexedLabel and prediction are both indexed values,
// so they can be compared directly.
val classEvaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

// 5-fold cross-validation over the whole pipeline; seeded so folds are reproducible.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(classEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)
  .setSeed(1234L)

// Run cross-validation and keep the model trained with the best parameter set.
val cvModel = cv.fit(trainingDF)

// Inspect the cross-validation result (each expression prints its value in spark-shell).
cvModel.extractParamMap()                // all parameters
// avgMetrics and getEstimatorParamMaps have the same length and correspond element by element.
cvModel.avgMetrics.length
cvModel.avgMetrics                       // average metric for each parameter combination
cvModel.getEstimatorParamMaps.length
cvModel.getEstimatorParamMaps            // the parameter combinations that were tried
cvModel.getEvaluator.extractParamMap()   // evaluator parameters
cvModel.getEvaluator.isLargerBetter      // whether a larger metric value is better
cvModel.getNumFolds                      // number of folds

//################################
// Test the selected model on the held-out split.
val testPredictions: DataFrame = cvModel.transform(testDF)

// "features" was previously selected twice, which produced a duplicated column.
val predictDF: DataFrame = testPredictions.selectExpr("predictedLabel", "label", "features")
predictDF.show(20, false)

// Report the held-out error in the same format as the untuned model above.
val testAccuracy = classEvaluator.evaluate(testPredictions)
println("Test Error = " + (1.0 - testAccuracy))