Spark:交叉验证选择参数集
spark的交叉验证和python sklearn库的交叉验证不太一样,python sklearn库cross_validation
用来交叉验证选择模型,然后输出得分,而模型参数的选择同交叉验证是分开的模块。
而spark的org.apache.spark.ml.tuning
包下的CrossValidator
交叉验证类能够同时选择不同的参数组合来进行交叉验证,然后输出最好的模型(此模型是用最好的参数组合训练的)。
用CrossValidator(Estimator)
的fit
方法进行交叉验证会输出CrossValidatorModel(Transformer)
。
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.Pipeline
val dataset = spark.table("cons_feature_clean")
//这里将数据分为两部分,train部分用来进行交叉验证选择最佳模型,test部分进行模型测试
//要注意,虽然交叉验证本身也会将数据进行训练和测试的切分,但是和此处是不同的,交叉验证本身会将train再进行切分然后比较模型得分。
val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))
val labelIndexer = new StringIndexer().setInputCol("vehi_type").setOutputCol("label").fit(dataset) //根据原始数据将分类列建立索引,注意是用dataset不是用train
val vectorAssem = new VectorAssembler().setOutputCol("features")
val estimator = new DecisionTreeClassifier()
//返回以下参数的所有可能组合
val params = new ParamGridBuilder().
addGrid(vectorAssem.inputCols,
Array(Array("sum_cast", "run_count","sum_drive_hour", "max_per_line_count"))).
addGrid(estimator.impurity, Array("entropy")).
addGrid(estimator.maxBins, Array(128, 256, 512)).
addGrid(estimator.maxDepth, Array(10, 13, 20)).
addGrid(estimator.minInfoGain, Array(0.0, 0.001)).
addGrid(estimator.minInstancesPerNode, Array(5, 10, 20)).
build()
val pipeline = new Pipeline().setStages(Array(labelIndexer, vectorAssem, estimator))
val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
val cv = new CrossValidator().
setEvaluator(evaluator).
setEstimator(pipeline).
setEstimatorParamMaps(params).
setNumFolds(2) //进行2folds交叉验证
val model = cv.fit(train)
val predictedTest = model.transform(test)
evaluator.evaluate(predictedTest)
//获取最优参数:
import org.apache.spark.ml.PipelineModel
val pipeModel = model.bestModel.asInstanceOf[PipelineModel]
val bestVA = pipeModel.stages(1).asInstanceOf[VectorAssembler]
bestVA.getInputCols // Array[String] = Array(sum_cast, run_count, sum_drive_hour, max_per_line_count)
pipeModel.stages(2).extractParamMap
/*
org.apache.spark.ml.param.ParamMap =
{
dtc_be7d2b335869-cacheNodeIds: false,
dtc_be7d2b335869-checkpointInterval: 10,
dtc_be7d2b335869-featuresCol: features,
dtc_be7d2b335869-impurity: entropy,
dtc_be7d2b335869-labelCol: label,
dtc_be7d2b335869-maxBins: 512,
dtc_be7d2b335869-maxDepth: 10,
dtc_be7d2b335869-maxMemoryInMB: 256,
dtc_be7d2b335869-minInfoGain: 0.0,
dtc_be7d2b335869-minInstancesPerNode: 20,
dtc_be7d2b335869-predictionCol: prediction,
dtc_be7d2b335869-probabilityCol: probability,
dtc_be7d2b335869-rawPredictionCol: rawPrediction,
dtc_be7d2b335869-seed: 159147643
} */