lakala GradientBoostedTrees
/**
 * Created by lkl on 2017/12/6.
 */
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.ArrayBuffer

object GradientBoostingClassificationForLK {
  // Reference: http://blog.csdn.net/xubo245/article/details/51499643
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("GradientBoostingClassificationForLK")
    val sc = new SparkContext(conf)
    // sc is an existing SparkContext.
    val hc = new HiveContext(sc)
    if (args.length != 3) {
      println("Please pass in: the database name and table name of the training data, and the model run date")
      System.exit(0)
    }
    // Arguments: database name, table name, date used in the output paths
    // val database = args(0)
    // val table = args(1)
    // val date = args(2)
    // lkl_card_score.overdue_result_all_new_woe
    val format = new java.text.SimpleDateFormat("yyyyMMdd")
    val database = "lkl_card_score"
    val table = "overdue_result_all_new_woe"
    val date = format.format(new java.util.Date())

    // Extract the dataset as RDD[LabeledPoint]
    // val data = hc.sql(s"select * from $database.$table").map {
    val data = hc.sql(s"select * from lkl_card_score.overdue_result_all_new_woe").map { row =>
      val arr = new ArrayBuffer[Double]()
      // Skip the label and contact columns (0 to 2); nulls and strings become 0.0
      for (i <- 3 until row.size) {
        if (row.isNullAt(i)) arr += 0.0
        else if (row.get(i).isInstanceOf[Int]) arr += row.getInt(i).toDouble
        else if (row.get(i).isInstanceOf[Double]) arr += row.getDouble(i)
        else if (row.get(i).isInstanceOf[Long]) arr += row.getLong(i).toDouble
        else if (row.get(i).isInstanceOf[String]) arr += 0.0
      }
      LabeledPoint(row.getInt(0), Vectors.dense(arr.toArray))
    }

    // Split the data into training and test sets (30% held out for testing)
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))

    // Train a GradientBoostedTrees model.
    // The defaultParams for Classification use LogLoss by default.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.setNumIterations(3) // Note: use more iterations in practice.
    boostingStrategy.treeStrategy.setNumClasses(2)
    boostingStrategy.treeStrategy.setMaxDepth(5)
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    // boostingStrategy.treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]())
    val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

    // Evaluate the model on test instances. BinaryClassificationMetrics expects
    // (score, label) pairs, so the prediction goes first in the tuple.
    val predictionAndLabels = testData.map { point =>
      val prediction = model.predict(point.features)
      (prediction, point.label)
    }
    predictionAndLabels.map(x => "prediction: " + x._1 + " --> label: " + x._2)
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/predictionAndLabels")

    // ===================================================================
    // Evaluate the model with BinaryClassificationMetrics
    val metrics = new BinaryClassificationMetrics(predictionAndLabels)

    // Precision by threshold
    val precision = metrics.precisionByThreshold
    precision.map { case (t, p) => "Threshold: " + t + " Precision: " + p }
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/precision")

    // Recall by threshold
    val recall = metrics.recallByThreshold
    recall.map { case (t, r) => "Threshold: " + t + " Recall: " + r }
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/recall")
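    /*
     * Note (a hedged sketch, not part of the original pipeline): model.predict
     * returns hard 0.0/1.0 class labels, so the threshold-based metrics in this
     * section only ever see the two thresholds 0.0 and 1.0. For a meaningful
     * threshold sweep, one option is to score with the raw GBT margin, i.e. the
     * weighted sum of the individual tree predictions:
     *
     *   val scoreAndLabels = testData.map { point =>
     *     val margin = model.trees.zip(model.treeWeights)
     *       .map { case (tree, weight) => tree.predict(point.features) * weight }
     *       .sum
     *     (margin, point.label)
     *   }
     *   val marginMetrics = new BinaryClassificationMetrics(scoreAndLabels)
     */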
    // F-measure by threshold; beta weights recall in the F-measure (beta = 1 gives F1)
    val f1Score = metrics.fMeasureByThreshold
    f1Score.map(x => "Threshold: " + x._1 + " --> F-score: " + x._2 + " (beta = 1)")
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/f1Score")
    /*
     * To pick a threshold, F1 is the most suitable of these three metrics:
     * find the maximum F1, and the threshold it occurs at is the best one.
     * f1Score is an RDD of (threshold, F1) pairs, so for example:
     * val (bestThreshold, maxFMeasure) = f1Score.max()(Ordering.by(_._2))
     */

    // Precision-recall curve
    val prc = metrics.pr
    prc.map(x => "Recall: " + x._1 + " --> Precision: " + x._2)
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/prc")

    // AUPRC: area under the precision-recall curve
    val auPRC = metrics.areaUnderPR
    sc.makeRDD(Seq("Area under precision-recall curve = " + auPRC))
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/auPRC")

    // ROC curve
    val roc = metrics.roc
    roc.map(x => "FalsePositiveRate: " + x._1 + " --> Recall: " + x._2)
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/roc")

    // AUC
    val auROC = metrics.areaUnderROC
    sc.makeRDD(Seq("Area under ROC = " + auROC))
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/auROC")
    println("Area under ROC = " + auROC)

    // Test error: fraction of misclassified test instances
    val testErr = predictionAndLabels.filter(r => r._1 != r._2).count.toDouble / testData.count()
    sc.makeRDD(Seq("Test Error = " + testErr))
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/testErr")

    sc.makeRDD(Seq("Learned classification GBT model: " + model.toDebugString))
      .saveAsTextFile(s"hdfs://ns1/tmp/$date/GBDTclassification")
  }
}
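The job above discards the trained model once it finishes. If the model should be reused for later scoring, MLlib's GradientBoostedTreesModel supports save and load; a minimal sketch, assuming the same hdfs://ns1/tmp/$date output convention for the model path:

import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

// Persist the trained model alongside the evaluation output (path is an assumed convention).
model.save(sc, s"hdfs://ns1/tmp/$date/model")

// In a later scoring job, load it back and predict as before.
val sameModel = GradientBoostedTreesModel.load(sc, s"hdfs://ns1/tmp/$date/model")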
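The listing's own comment notes that 3 iterations is too few in practice. Instead of hand-picking numIterations, one option is MLlib's runWithValidation, which holds out a validation set and stops boosting early when validation error no longer improves. A sketch, with assumed split fractions:

// Carve a validation set out of the training data (80/20 is an assumption).
val Array(train, validation) = trainingData.randomSplit(Array(0.8, 0.2))

// Allow many iterations; early stopping on the validation set decides when to quit.
boostingStrategy.setNumIterations(100)
val validatedModel = new GradientBoostedTrees(boostingStrategy)
  .runWithValidation(train, validation)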