Spark: Selecting a Parameter Set with Cross-Validation

Spark's cross-validation works differently from scikit-learn's. In scikit-learn, the cross_validation module is used to cross-validate a model and report a score, while hyperparameter selection lives in a separate module.
Spark's CrossValidator, in the org.apache.spark.ml.tuning package, does both at once: it cross-validates every candidate parameter combination and then returns the best model, i.e. the one trained with the best-scoring combination.

Calling the fit method of a CrossValidator (an Estimator) runs the cross-validation and returns a CrossValidatorModel (a Transformer).
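As a minimal sketch of that type relationship (the LogisticRegression estimator and the someTrainDF/someTestDF DataFrames are illustrative placeholders, not part of the example below):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}

val lr = new LogisticRegression()
val cvSketch = new CrossValidator().
  setEstimator(lr).  // any Estimator works here, not only a Pipeline
  setEvaluator(new BinaryClassificationEvaluator()).
  setEstimatorParamMaps(new ParamGridBuilder().addGrid(lr.regParam, Array(0.01, 0.1)).build()).
  setNumFolds(3)

// fit turns the Estimator into a Model: CrossValidatorModel is a Transformer
val cvModelSketch: CrossValidatorModel = cvSketch.fit(someTrainDF)  // someTrainDF: an assumed DataFrame
val scoredSketch = cvModelSketch.transform(someTestDF)              // someTestDF: an assumed DataFrame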

import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.Pipeline

val dataset = spark.table("cons_feature_clean")
//Split the data in two: train is used for cross-validation and model selection, test for the final evaluation.
//Note that cross-validation also splits its input internally: CrossValidator re-splits train into folds and compares model scores across them, which is separate from this train/test split.
val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))
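If the split needs to be reproducible across runs, randomSplit also takes a seed; the line above would become (the value 42 is arbitrary):

val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2), seed = 42L)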

val labelIndexer = new StringIndexer().setInputCol("vehi_type").setOutputCol("label").fit(dataset) //index the categorical column; note this is fit on dataset, not on train, so every label value is covered
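Since the model is trained on the indexed label, its predictions come back as indices as well. A sketch of converting them to the original vehi_type strings with IndexToString (the output column name is illustrative):

import org.apache.spark.ml.feature.IndexToString
val labelConverter = new IndexToString().
  setInputCol("prediction").
  setOutputCol("predicted_vehi_type").
  setLabels(labelIndexer.labels)  // the label array learned by the fitted StringIndexerModel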

val vectorAssem = new VectorAssembler().setOutputCol("features")
val estimator = new DecisionTreeClassifier()

//Build all possible combinations of the following parameters
val params = new ParamGridBuilder().
  addGrid(vectorAssem.inputCols,
    Array(Array("sum_cast", "run_count", "sum_drive_hour", "max_per_line_count"))).  // a single candidate feature set
  addGrid(estimator.impurity, Array("entropy")).
  addGrid(estimator.maxBins, Array(128, 256, 512)).
  addGrid(estimator.maxDepth, Array(10, 13, 20)).
  addGrid(estimator.minInfoGain, Array(0.0, 0.001)).
  addGrid(estimator.minInstancesPerNode, Array(5, 10, 20)).
  build()
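The grid above expands to 1 × 1 × 3 × 3 × 2 × 3 = 54 parameter combinations. With the 2 folds configured below, CrossValidator trains 108 models (plus one final refit of the best combination on all of train), so grid size directly drives runtime:

params.length  // 54: build() returns one ParamMap per combination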

val pipeline = new Pipeline().setStages(Array(labelIndexer, vectorAssem, estimator))
val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
val cv = new CrossValidator().
  setEvaluator(evaluator).
  setEstimator(pipeline).
  setEstimatorParamMaps(params).
  setNumFolds(2)  // 2-fold cross-validation
val model = cv.fit(train)
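On Spark 2.3 or later, the candidate models can also be trained concurrently; a sketch (the parallelism level of 4 is arbitrary, and it must be set before calling fit):

cv.setParallelism(4)  // evaluate up to 4 candidate models at a time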

val predictedTest = model.transform(test)
evaluator.evaluate(predictedTest)  // accuracy of the best model on the held-out test set
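The cross-validation score of every parameter combination is also kept on the model, which makes it easy to see how close the runners-up were (avgMetrics is ordered like the param maps):

model.avgMetrics  // average metric per parameter combination, across folds
model.getEstimatorParamMaps.zip(model.avgMetrics).maxBy(_._2)  // the winning combination and its score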

//Retrieve the best parameters:
import org.apache.spark.ml.PipelineModel
val pipeModel = model.bestModel.asInstanceOf[PipelineModel]  // bestModel is the pipeline refit with the winning parameters
val bestVA = pipeModel.stages(1).asInstanceOf[VectorAssembler]
bestVA.getInputCols  // Array[String] = Array(sum_cast, run_count, sum_drive_hour, max_per_line_count)
pipeModel.stages(2).extractParamMap  // parameters of the fitted decision tree
/*
org.apache.spark.ml.param.ParamMap =
{
	dtc_be7d2b335869-cacheNodeIds: false,
	dtc_be7d2b335869-checkpointInterval: 10,
	dtc_be7d2b335869-featuresCol: features,
	dtc_be7d2b335869-impurity: entropy,
	dtc_be7d2b335869-labelCol: label,
	dtc_be7d2b335869-maxBins: 512,
	dtc_be7d2b335869-maxDepth: 10,
	dtc_be7d2b335869-maxMemoryInMB: 256,
	dtc_be7d2b335869-minInfoGain: 0.0,
	dtc_be7d2b335869-minInstancesPerNode: 20,
	dtc_be7d2b335869-predictionCol: prediction,
	dtc_be7d2b335869-probabilityCol: probability,
	dtc_be7d2b335869-rawPredictionCol: rawPrediction,
	dtc_be7d2b335869-seed: 159147643
} */
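Note that stages(2) of a fitted PipelineModel is the trained DecisionTreeClassificationModel, so the chosen hyperparameters can also be read through its typed getters (expected values taken from the output above):

import org.apache.spark.ml.classification.DecisionTreeClassificationModel
val bestTree = pipeModel.stages(2).asInstanceOf[DecisionTreeClassificationModel]
bestTree.getMaxDepth             // 10
bestTree.getMinInstancesPerNode  // 20
// bestTree.toDebugString        // prints the full structure of the learned tree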