代码:

package com.test

import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.SparkSession

/**
 * Trains a binary logistic-regression model on `data/breast_cancer.csv` and prints
 * its accuracy on a held-out split, first with the model's own predictions and then
 * with a custom probability threshold.
 *
 * Each CSV line is parsed as comma-separated numbers: every column except the last
 * is a feature and the last column is the label. The accuracy arithmetic below
 * assumes labels are encoded as 0.0 / 1.0 (|label - prediction| is then 0 for a hit
 * and 1 for a miss) -- NOTE(review): confirm against the data file.
 */
object Test03 {

  /**
   * Program entry point.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf()
    conf.setMaster("local")
    val spark = SparkSession.builder().config(conf).appName("Logistic linear regression").getOrCreate()

    // Release the session even when loading, parsing, training or scoring throws;
    // previously close() was only reached on the success path.
    try {
      import spark.implicits._

      val dataRdd = spark.sparkContext.textFile("data/breast_cancer.csv")
      val data = dataRdd.map { line =>
        val fields = line.split(",")
        // All columns but the last are features; the last column is the label.
        (new DenseVector(fields.init.map(_.toDouble)), fields.last.toDouble)
      }.toDF("features", "label")

      // 70% training / 30% test, with a fixed seed so the split is reproducible.
      val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)

      val (trainingData, testData) = (splits(0), splits(1))

      val lr = new LogisticRegression().setMaxIter(100)

      val lrModel = lr.fit(trainingData)

      println(s"w1~wn: ${lrModel.coefficients} w0: ${lrModel.intercept}")

      // Score the held-out test set.
      val testRest = lrModel.transform(testData)
      testRest.show(false)

      // Counting the test set launches a Spark job, so do it once and reuse it.
      val testCount = testData.count()

      // Accuracy computed by hand: sum the per-row error and subtract the error
      // rate from 1.
      val errorSum = testRest.rdd.map { row =>
        // True class of this sample.
        val label = row.getAs[Double]("label")
        // Class predicted by the model for this sample's features.
        val prediction = row.getAs[Double]("prediction")
        // 0 when the prediction is correct, 1 when it is wrong.
        math.abs(label - prediction)
      }.sum()
      println("正确率:" + (1 - (errorSum / testCount)))
      // Equivalent to the manual computation above.
      println("正确率:" + lrModel.evaluate(testData).accuracy)

      // Same accuracy computation, but classify as 1 whenever the probability of
      // class 1 exceeds a custom threshold.
      val threshold = 0.3
      val thresholdErrorSum = testRest.rdd.map { row =>
        val probability = row.getAs[DenseVector]("probability")
        val label = row.getAs[Double]("label")
        // Index 1 of the probability vector is read as the class-1 probability.
        val score = probability(1)
        val prediction = if (score > threshold) 1.0 else 0.0
        math.abs(label - prediction)
      }.sum()
      println("自定义分类阈值 正确率:" + (1 - (thresholdErrorSum / testCount)))
    } finally {
      spark.close()
    }
  }

}

 

结果:

+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+----------------------------------------+-----------+----------+
|features                                                                                                                                                                                                                   |label|rawPrediction                           |probability|prediction|
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+----------------------------------------+-----------+----------+
|[7.729,25.49,47.98,178.8,0.08098,0.04878,0.0,0.0,0.187,0.07285,0.3777,1.462,2.492,19.14,0.01266,0.009692,0.0,0.0,0.02882,0.006872,9.077,30.92,57.17,248.0,0.1256,0.0834,0.0,0.0,0.3058,0.09938]                            |1.0  |[-6902.544911715737,6902.544911715737]  |[0.0,1.0]  |1.0       |
|[7.76,24.54,47.92,181.0,0.05263,0.04362,0.0,0.0,0.1587,0.05884,0.3857,1.428,2.548,19.15,0.007189,0.00466,0.0,0.0,0.02676,0.002783,9.456,30.37,59.16,268.6,0.08996,0.06444,0.0,0.0,0.2871,0.07039]                          |1.0  |[-4229.951065882061,4229.951065882061]  |[0.0,1.0]  |1.0       |
|[8.618,11.79,54.34,224.5,0.09752,0.05272,0.02061,0.007799,0.1683,0.07187,0.1559,0.5796,1.046,8.322,0.01011,0.01055,0.01981,0.005742,0.0209,0.002788,9.507,15.4,59.9,274.9,0.1733,0.1239,0.1168,0.04419,0.322,0.09026]      |1.0  |[-10129.616436167826,10129.616436167826]|[0.0,1.0]  |1.0       |
|[8.671,14.45,54.42,227.2,0.09138,0.04276,0.0,0.0,0.1722,0.06724,0.2204,0.7873,1.435,11.36,0.009172,0.008007,0.0,0.0,0.02711,0.003399,9.262,17.04,58.36,259.2,0.1162,0.07057,0.0,0.0,0.2592,0.07848]                        |1.0  |[-13336.888611143484,13336.888611143484]|[0.0,1.0]  |1.0       |
|[8.726,15.83,55.84,230.9,0.115,0.08201,0.04132,0.01924,0.1649,0.07633,0.1665,0.5864,1.354,8.966,0.008261,0.02213,0.03259,0.0104,0.01708,0.003806,9.628,19.62,64.48,284.4,0.1724,0.2364,0.2456,0.105,0.2926,0.1017]         |1.0  |[-9268.251080440577,9268.251080440577]  |[0.0,1.0]  |1.0       |
|[8.734,16.84,55.27,234.3,0.1039,0.07428,0.0,0.0,0.1985,0.07098,0.5169,2.079,3.167,28.85,0.01582,0.01966,0.0,0.0,0.01865,0.006736,10.17,22.8,64.01,317.0,0.146,0.131,0.0,0.0,0.2445,0.08865]                                |1.0  |[-10929.784584048108,10929.784584048108]|[0.0,1.0]  |1.0       |
|[9.295,13.9,59.96,257.8,0.1371,0.1225,0.03332,0.02421,0.2197,0.07696,0.3538,1.13,2.388,19.63,0.01546,0.0254,0.02197,0.0158,0.03997,0.003901,10.57,17.84,67.84,326.6,0.185,0.2097,0.09996,0.07262,0.3681,0.08982]           |1.0  |[-9972.504809621163,9972.504809621163]  |[0.0,1.0]  |1.0       |
|[9.333,21.94,59.01,264.0,0.0924,0.05605,0.03996,0.01282,0.1692,0.06576,0.3013,1.879,2.121,17.86,0.01094,0.01834,0.03996,0.01282,0.03759,0.004623,9.845,25.05,62.86,295.8,0.1103,0.08298,0.07993,0.02564,0.2435,0.07393]    |1.0  |[-10874.187824751454,10874.187824751454]|[0.0,1.0]  |1.0       |
|[9.397,21.68,59.75,268.8,0.07969,0.06053,0.03735,0.005128,0.1274,0.06724,0.1186,1.182,1.174,6.802,0.005515,0.02674,0.03735,0.005128,0.01951,0.004583,9.965,27.99,66.61,301.0,0.1086,0.1887,0.1868,0.02564,0.2376,0.09206]  |1.0  |[-9687.163679996896,9687.163679996896]  |[0.0,1.0]  |1.0       |
|[9.676,13.14,64.12,272.5,0.1255,0.2204,0.1188,0.07038,0.2057,0.09575,0.2744,1.39,1.787,17.67,0.02177,0.04888,0.05189,0.0145,0.02632,0.01148,10.6,18.04,69.47,328.1,0.2006,0.3663,0.2913,0.1075,0.2848,0.1364]              |1.0  |[-10425.819568377874,10425.819568377874]|[0.0,1.0]  |1.0       |
|[9.742,19.12,61.93,289.7,0.1075,0.08333,0.008934,0.01967,0.2538,0.07029,0.6965,1.747,4.607,43.52,0.01307,0.01885,0.006021,0.01052,0.031,0.004225,11.21,23.17,71.79,380.9,0.1398,0.1352,0.02085,0.04589,0.3196,0.08009]     |1.0  |[-7470.623970453418,7470.623970453418]  |[0.0,1.0]  |1.0       |
|[9.777,16.99,62.5,290.2,0.1037,0.08404,0.04334,0.01778,0.1584,0.07065,0.403,1.424,2.747,22.87,0.01385,0.02932,0.02722,0.01023,0.03281,0.004638,11.05,21.47,71.68,367.0,0.1467,0.1765,0.13,0.05334,0.2533,0.08468]          |1.0  |[-9631.107650152822,9631.107650152822]  |[0.0,1.0]  |1.0       |
|[9.787,19.94,62.11,294.5,0.1024,0.05301,0.006829,0.007937,0.135,0.0689,0.335,2.043,2.132,20.05,0.01113,0.01463,0.005308,0.00525,0.01801,0.005667,10.92,26.29,68.81,366.1,0.1316,0.09473,0.02049,0.02381,0.1934,0.08988]    |1.0  |[-10574.948203901578,10574.948203901578]|[0.0,1.0]  |1.0       |
|[10.03,21.28,63.19,307.3,0.08117,0.03912,0.00247,0.005159,0.163,0.06439,0.1851,1.341,1.184,11.6,0.005724,0.005697,0.002074,0.003527,0.01445,0.002411,11.11,28.94,69.92,376.3,0.1126,0.07094,0.01235,0.02579,0.2349,0.08061]|1.0  |[-8893.557477503311,8893.557477503311]  |[0.0,1.0]  |1.0       |
|[10.08,15.11,63.76,317.5,0.09267,0.04695,0.001597,0.002404,0.1703,0.06048,0.4245,1.268,2.68,26.43,0.01439,0.012,0.001597,0.002404,0.02538,0.00347,11.87,21.18,75.39,437.0,0.1521,0.1019,0.00692,0.01042,0.2933,0.07697]    |1.0  |[-5668.804179132235,5668.804179132235]  |[0.0,1.0]  |1.0       |
|[10.17,14.88,64.55,311.9,0.1134,0.08061,0.01084,0.0129,0.2743,0.0696,0.5158,1.441,3.312,34.62,0.007514,0.01099,0.007665,0.008193,0.04183,0.005953,11.02,17.45,69.86,368.6,0.1275,0.09866,0.02168,0.02579,0.3557,0.0802]    |1.0  |[-11852.030184143723,11852.030184143723]|[0.0,1.0]  |1.0       |
|[10.32,16.35,65.31,324.9,0.09434,0.04994,0.01012,0.005495,0.1885,0.06201,0.2104,0.967,1.356,12.97,0.007086,0.007247,0.01012,0.005495,0.0156,0.002606,11.25,21.77,71.12,384.9,0.1285,0.08842,0.04384,0.02381,0.2681,0.07399]|1.0  |[-9073.703438558807,9073.703438558807]  |[0.0,1.0]  |1.0       |
|[10.48,19.86,66.72,337.7,0.107,0.05971,0.04831,0.0307,0.1737,0.0644,0.3719,2.612,2.517,23.22,0.01604,0.01386,0.01865,0.01133,0.03476,0.00356,11.48,29.46,73.68,402.8,0.1515,0.1026,0.1181,0.06736,0.2883,0.07748]          |1.0  |[-4389.849643347756,4389.849643347756]  |[0.0,1.0]  |1.0       |
|[10.51,23.09,66.85,334.2,0.1015,0.06797,0.02495,0.01875,0.1695,0.06556,0.2868,1.143,2.289,20.56,0.01017,0.01443,0.01861,0.0125,0.03464,0.001971,10.93,24.22,70.1,362.7,0.1143,0.08614,0.04158,0.03125,0.2227,0.06777]      |1.0  |[-10834.90580709216,10834.90580709216]  |[0.0,1.0]  |1.0       |
|[10.57,20.22,70.15,338.3,0.09073,0.166,0.228,0.05941,0.2188,0.0845,0.1115,1.231,2.363,7.228,0.008499,0.07643,0.1535,0.02919,0.01617,0.0122,10.85,22.82,76.51,351.9,0.1143,0.3619,0.603,0.1465,0.2597,0.12]                 |1.0  |[-9212.082801976405,9212.082801976405]  |[0.0,1.0]  |1.0       |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+----------------------------------------+-----------+----------+
only showing top 20 rows
正确率:0.9766081871345029
正确率:0.9766081871345029

自定义分类阈值 正确率:0.9766081871345029

 

posted on 2021-04-11 14:23  陕西小楞娃  阅读(94)  评论(0)  编辑  收藏  举报