交叉验证_自动获取模型最优超参数

package Spark_MLlib

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

/**
  * One parsed input row for the hyper-parameter tuning and model selection example:
  * a dense numeric feature vector together with its raw string class label.
  *
  * @param features the numeric feature values of the row
  * @param label    the class name as it appears in the input file
  */
case class schema_source(
    features: Vector,
    label: String
)
/**
  * Hyper-parameter tuning with k-fold cross validation for a logistic regression pipeline.
  *
  * Reads a comma-separated text file (four numeric features followed by a string label per
  * line), tunes a label/feature indexing + logistic regression pipeline with `CrossValidator`
  * over a `regParam` x `elasticNetParam` grid, reports accuracy on a held-out split and prints
  * the parameters of the best model that was selected.
  *
  * Usage: an optional first command-line argument overrides the input path.
  */
object 交叉验证_调参_逻辑回归 {
  val spark = SparkSession.builder().master("local[2]").getOrCreate()
  import spark.implicits._

  /** Input file used when no path is given on the command line. */
  private val DefaultInputPath = "file:///home/soyo/桌面/spark编程测试数据/soyo.txt"

  /** Fixed seed so the train/test split (and therefore the reported metrics) is reproducible. */
  private val SplitSeed = 1234L

  /**
    * Entry point.
    *
    * @param args optional; `args(0)` is the input path (defaults to `DefaultInputPath`)
    */
  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    // Always release the Spark context, even when parsing or training throws.
    try run(inputPath)
    finally spark.stop()
  }

  /** Loads the data, tunes the pipeline by cross validation, evaluates it and prints the best model. */
  private def run(inputPath: String): Unit = {
    val data = loadData(inputPath)
    data.show()

    // The indexers are fitted on the full data set so that every label / feature category that
    // can appear in the test split is known to these pre-fitted pipeline stages.
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
    val featuresIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(data)
    val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3), SplitSeed)

    val lr = new LogisticRegression()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setMaxIter(50)
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)
    labelIndexer.labels.foreach(println)

    // ML workflow: index label -> index features -> classify -> map the prediction back to a label.
    val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featuresIndexer, lr, labelConverter))

    // Model-selection criterion for cross validation (f1, which is also the evaluator's default).
    val cvEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("f1")
    // Parameter grid: 3 regParam values x 2 elasticNetParam values = 6 candidates.
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.01, 0.3, 0.8))
      .addGrid(lr.elasticNetParam, Array(0.3, 0.9))
      .build()
    // 3-fold cross validation: each candidate is trained on 2/3 of trainData and scored on the rest.
    val cv = new CrossValidator()
      .setEstimator(lrPipeline)
      .setEvaluator(cvEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    val cvModel = cv.fit(trainData)

    val lrPrediction = cvModel.transform(testData)
    lrPrediction.show()
    // The evaluator's default metric is f1; request accuracy explicitly so that the values printed
    // as accuracy / error rate really are accuracy and 1 - accuracy.
    val accuracyEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val lrAccuracy = accuracyEvaluator.evaluate(lrPrediction)
    println("准确率为: " + lrAccuracy)
    println("错误率为: " + (1 - lrAccuracy))

    // Locate the fitted logistic-regression stage by type rather than by its index in the pipeline.
    val lrModel = cvModel.bestModel match {
      case best: PipelineModel =>
        best.stages
          .collectFirst { case m: LogisticRegressionModel => m }
          .getOrElse(throw new IllegalStateException("best pipeline contains no LogisticRegressionModel stage"))
      case other =>
        throw new IllegalStateException(s"expected a PipelineModel but got ${other.getClass.getName}")
    }

    // With family = auto the model is binomial or multinomial depending on the number of labels
    // (3 classes here => multinomial), so the labels do not claim "binomial".
    println("逻辑回归模型系数矩阵: " + lrModel.coefficientMatrix)
    println("逻辑回归模型的截距向量: " + lrModel.interceptVector)
    println("类的数量(标签可以使用的值): " + lrModel.numClasses)
    println("模型所接受的特征的数量: " + lrModel.numFeatures)

    println("所有参数的设置为: " + lrModel.explainParams())
    println("最优的regParam的值为: " + lrModel.explainParam(lrModel.regParam))
    println("最优的elasticNetParam的值为: " + lrModel.explainParam(lrModel.elasticNetParam))
  }

  /**
    * Parses `inputPath` into a DataFrame with `features` and `label` columns.
    *
    * Each non-blank line must be `f0,f1,f2,f3,label`. Blank lines (e.g. a trailing newline)
    * are skipped, and surrounding whitespace, including a Windows carriage return, is trimmed
    * from every field so it cannot leak into the label.
    */
  private def loadData(inputPath: String) =
    spark.sparkContext.textFile(inputPath)
      .filter(_.trim.nonEmpty)
      .map(_.split(",").map(_.trim))
      .map(x => schema_source(Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble), x(4)))
      .toDF()
      .cache() // reused by show, both indexers, randomSplit and cross validation

}
+-----------------+-----+
|         features|label|
+-----------------+-----+
|[5.1,3.5,1.4,0.2]|soyo1|
|[4.9,3.0,1.4,0.2]|soyo1|
|[4.7,3.2,1.3,0.2]|soyo1|
|[4.6,3.1,1.5,0.2]|soyo1|
|[5.0,3.6,1.4,0.2]|soyo1|
|[5.4,3.9,1.7,0.4]|soyo1|
|[4.6,3.4,1.4,0.3]|soyo1|
|[5.0,3.4,1.5,0.2]|soyo1|
|[4.4,2.9,1.4,0.2]|soyo1|
|[4.9,3.1,1.5,0.1]|soyo1|
|[5.4,3.7,1.5,0.2]|soyo1|
|[4.8,3.4,1.6,0.2]|soyo1|
|[4.8,3.0,1.4,0.1]|soyo1|
|[4.3,3.0,1.1,0.1]|soyo1|
|[5.8,4.0,1.2,0.2]|soyo1|
|[5.7,4.4,1.5,0.4]|soyo1|
|[5.4,3.9,1.3,0.4]|soyo1|
|[5.1,3.5,1.4,0.3]|soyo1|
|[5.7,3.8,1.7,0.3]|soyo1|
|[5.1,3.8,1.5,0.3]|soyo1|
+-----------------+-----+
only showing top 20 rows

soyo2
soyo1
soyo3
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
|         features|label|indexedLabel|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
|[4.3,3.0,1.1,0.1]|soyo1|         1.0|[4.3,3.0,1.1,0.1]|[-0.2949197997435...|[0.00821657808181...|       1.0|         soyo1|
|[4.4,2.9,1.4,0.2]|soyo1|         1.0|[4.4,2.9,1.4,0.2]|[-0.1436502505351...|[0.02310764702310...|       1.0|         soyo1|
|[4.6,3.1,1.5,0.2]|soyo1|         1.0|[4.6,3.1,1.5,0.2]|[-0.1980725396328...|[0.01584026165726...|       1.0|         soyo1|
|[4.8,3.0,1.4,0.1]|soyo1|         1.0|[4.8,3.0,1.4,0.1]|[-0.0360182992158...|[0.01909506488946...|       1.0|         soyo1|
|[4.8,3.1,1.6,0.2]|soyo1|         1.0|[4.8,3.1,1.6,0.2]|[-0.0963956817735...|[0.02165865158723...|       1.0|         soyo1|
|[4.8,3.4,1.6,0.2]|soyo1|         1.0|[4.8,3.4,1.6,0.2]|[-0.3305444022091...|[0.00764403083532...|       1.0|         soyo1|
|[4.9,2.4,3.3,1.0]|soyo2|         0.0|[4.9,2.4,3.3,1.0]|[0.64687664475266...|[0.83588965920895...|       0.0|         soyo2|
|[4.9,3.0,1.4,0.2]|soyo1|         1.0|[4.9,3.0,1.4,0.2]|[0.00894554123863...|[0.02696343238302...|       1.0|         soyo1|
|[5.0,3.5,1.6,0.6]|soyo1|         1.0|[5.0,3.5,1.6,0.6]|[-0.3209967599706...|[0.01781564148264...|       1.0|         soyo1|
|[5.0,3.6,1.4,0.2]|soyo1|         1.0|[5.0,3.6,1.4,0.2]|[-0.4132228265822...|[0.00370148550004...|       1.0|         soyo1|
|[5.1,3.7,1.5,0.4]|soyo1|         1.0|[5.1,3.7,1.5,0.4]|[-0.4380550804437...|[0.00533390253840...|       1.0|         soyo1|
|[5.1,3.8,1.9,0.4]|soyo1|         1.0|[5.1,3.8,1.9,0.4]|[-0.4784298068885...|[0.00593236888116...|       1.0|         soyo1|
|[5.2,2.7,3.9,1.4]|soyo2|         0.0|[5.2,2.7,3.9,1.4]|[0.60296648363520...|[0.65499655703255...|       0.0|         soyo2|
|[5.2,3.5,1.5,0.2]|soyo1|         1.0|[5.2,3.5,1.5,0.2]|[-0.2334963952443...|[0.00721300202565...|       1.0|         soyo1|
|[5.3,3.7,1.5,0.2]|soyo1|         1.0|[5.3,3.7,1.5,0.2]|[-0.3434664691509...|[0.00396451436269...|       1.0|         soyo1|
|[5.4,3.4,1.5,0.4]|soyo1|         1.0|[5.4,3.4,1.5,0.4]|[-0.0655191408567...|[0.02050202848213...|       1.0|         soyo1|
|[5.4,3.4,1.7,0.2]|soyo1|         1.0|[5.4,3.4,1.7,0.2]|[-0.0443512521479...|[0.01568504280438...|       1.0|         soyo1|
|[5.4,3.9,1.3,0.4]|soyo1|         1.0|[5.4,3.9,1.3,0.4]|[-0.4746044317663...|[0.00285607924154...|       1.0|         soyo1|
|[5.4,3.9,1.7,0.4]|soyo1|         1.0|[5.4,3.9,1.7,0.4]|[-0.4369295847326...|[0.00451151133277...|       1.0|         soyo1|
|[5.5,2.3,4.0,1.3]|soyo2|         0.0|[5.5,2.3,4.0,1.3]|[1.06413594105520...|[0.51327715648015...|       0.0|         soyo2|
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 20 rows

准确率为: 0.9418343292582645
错误率为: 0.05816567074173551
二项逻辑回归模型系数矩阵: 0.4612907305046201    -0.7804957347855317  0.09418711758439907  -0.011652325959556013  
-0.559055378870932    2.7385209747134933   -1.052922922424876   -2.5223769474140303    
-0.07629895224519458  -3.6867236615320547  1.0014498171011217   4.581938360185545      
二项逻辑回归模型的截距向量: [-0.039423333303658874,0.0972586768296292,-0.05783534352597033]
类的数量(标签可以使用的值): 3
模型所接受的特征的数量: 4
所有参数的设置为: aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.9)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label, current: indexedLabel)
lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 50)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
standardization: whether to standardize the training features before fitting the model (default: true)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)

最优的regParam的值为: regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
最优的elasticNetParam的值为: elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.9)

 

posted @ 2017-11-13 16:07  soyosuyang  阅读(2223)  评论(0)  编辑  收藏  举报