// lkl风控 · 随机森林模型测试代码 (Spark 1.6) — lkl risk control: random-forest model test code for Spark 1.6
/**
 * Created by lkl on 2017/10/9.
 */
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.SparkConf
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.Row

/**
 * Trains a binary random-forest classifier (Spark 1.6 MLlib) on the Hive table
 * `fin_tec.uvcy2`, prints the hold-out error and the learned forest, saves and
 * reloads the model, then scores rows of `test.uvcy` with the reloaded model.
 *
 * Table layout, per the original author's note: column 0 is the ID-card number,
 * column 1 is the overdue label (stored as a double), columns 2.. are features.
 */
object uvcy {

  /**
   * Builds a dense feature vector from columns `2 until row.size`.
   *
   * Nulls become 0.0; any numeric type and numeric strings are converted to Double.
   * Any other type fails loudly: silently skipping a column (as the earlier
   * if/else chain did for e.g. Int or Decimal columns) would shift every later
   * feature to the wrong vector index.
   *
   * @throws IllegalArgumentException for a non-numeric, non-string column
   * @throws NumberFormatException    for a string column that is not a number
   */
  private def toFeatures(row: Row): org.apache.spark.mllib.linalg.Vector = {
    val values = (2 until row.size).map { i =>
      if (row.isNullAt(i)) 0.0
      else row.get(i) match {
        case n: java.lang.Number => n.doubleValue
        case s: String           => s.toDouble
        case other =>
          throw new IllegalArgumentException(
            s"Unsupported type ${other.getClass.getName} in column $i")
      }
    }
    Vectors.dense(values.toArray)
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("test") //setMaster("spark://192.168.0.37:7077")
    val sc = new SparkContext(conf)
    val hc = new HiveContext(sc)
    try {
      // Column 1 is the label; the remaining columns (from 2 on) are features.
      val data = hc.sql("select * from fin_tec.uvcy2")
        .map(row => LabeledPoint(row.getDouble(1), toFeatures(row)))
        .cache() // both random splits are computed from this RDD

      val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
      testData.cache() // scanned twice: once for predictions, once for the count

      val numClasses = 2
      val categoricalFeaturesInfo = Map[Int, Int]() // empty: every feature is treated as continuous
      val numTrees = 3
      val featureSubsetStrategy = "auto"
      val impurity = "gini"
      val maxDepth = 4
      val maxBins = 32

      val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

      val labelAndPreds = testData.map(point => (point.label, model.predict(point.features)))
      val testErr =
        labelAndPreds.filter { case (label, pred) => label != pred }.count().toDouble /
          testData.count()
      println("Test Error = " + testErr)
      println("Learned classification forest model:\n" + model.toDebugString)

      // NOTE(review): MLlib model save presumably refuses to overwrite an existing
      // directory — remove "uvcymodel/forest" before re-running.
      model.save(sc, "uvcymodel/forest")
      val sameModel = RandomForestModel.load(sc, "uvcymodel/forest")

      // Score the selected rows of test.uvcy with the reloaded model. (Previously this
      // step mapped over testData, leaving the test.uvcy features computed but unused.)
      val id = "110101000000000000"
      val newFeatures = hc.sql("select * from test.uvcy where i_l3_hk_amt=2150")
        .map(row => toFeatures(row))
      val scored = newFeatures.map(features => (id, sameModel.predict(features), features))
      scored.take(2).foreach(println)
    } finally {
      sc.stop()
    }
  }
}