Spark 逻辑回归LogisticRegression
1、概念
逻辑回归是预测分类相应的常用方法。广义线性回归的一个特例是预测结果的概率。在spark.ml逻辑回归中,可以使用二项逻辑回归来预测二元结果,
或者可以使用多项逻辑回归来预测多类结果。使用该family参数在这两种算法之间选择,或者保持不设置(缺省auto),Spark将推断出正确的变量。 通过将family参数设置为“多项式”,可以将多项逻辑回归用于二进制分类。它将产生两组系数和两个截距.
在分类问题中,我们尝试预测的是结果是否属于某一个类(例如正确或错误)。分类问题的例子有:判断一封电子邮件是否是垃圾邮件;判断一次金融交易是否是欺诈;
2、code,参考地址:https://github.com/asker124143222/spark-demo
package com.home.spark.ml import org.apache.spark.SparkConf import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.sql.{Dataset, Row, SparkSession} /** * @Description: 逻辑回归,二项分类预测 * **/ object Ex_BinomialLogisticRegression { def main(args: Array[String]): Unit = { val conf = new SparkConf(true).setMaster("local[*]").setAppName("spark ml label") val spark = SparkSession.builder().config(conf).getOrCreate() //rdd转换成df或者ds需要SparkSession实例的隐式转换 //导入隐式转换,注意这里的spark不是包名,而是SparkSession的对象名 import spark.implicits._ val data = spark.sparkContext.textFile("input/iris.data.txt") .map(_.split(",")) .map(a => Iris( Vectors.dense(a(0).toDouble, a(1).toDouble, a(2).toDouble, a(3).toDouble), a(4)) ).toDF() data.show() data.createOrReplaceTempView("iris") val TotalCount = spark.sql("select count(*) from iris") println("记录数: " + TotalCount.collect().take(1).mkString) //二项预测,由于样本数据有三类数据,排除Iris-setosa val df = spark.sql("select * from iris where label!='Iris-setosa'") df.map(r => r(1) + " : " + r(0)).collect().take(10).foreach(println) println("过滤后的记录数: " + df.count()) /* VectorIndexer 提高决策树或随机森林等ML方法的分类效果。 VectorIndexer是对数据集特征向量中的类别(离散值)特征(index categorical features categorical features )进行编号。 它能够自动判断那些特征是离散值型的特征,并对他们进行编号, 具体做法是通过设置一个maxCategories,特征向量中某一个特征不重复取值个数小于maxCategories,则被重新编号为0~K(K<=maxCategories-1)。 某一个特征不重复取值个数大于maxCategories,则该特征视为连续值,不会重新编号(不会发生任何改变) 假设maxCategories=5,那么特征列中非重复取值小于等于5的列将被重新索引 为了索引的稳定性,规定如果这个特征值为0,则一定会被编号成0,这样可以保证向量的稀疏度 maxCategories缺省是20 */ //对特征列和标签列进行索引转换 val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df) val featureIndexer = new VectorIndexer() // .setMaxCategories(5) //设置为5后,由于特征列的非重复值个数都大于5,所以不会发生任何转换,也就没有意义 .setInputCol("features").setOutputCol("indexedFeatures") .fit(df) //对原数据集划分训练数据(70%)和测试数据(30%) val Array(trainingData, testData): Array[Dataset[Row]] = df.randomSplit(Array(0.7, 0.3)) /** * LR建模 * setMaxIter设置最大迭代次数(默认100),具体迭代次数可能在不足最大迭代次数停止 * setTol设置容错(默认1e-6),每次迭代会计算一个误差,误差值随着迭代次数增加而减小,当误差小于设置容错,则停止迭代 * setRegParam设置正则化项系数(默认0),正则化主要用于防止过拟合现象,如果数据集较小,特征维数又多,易出现过拟合,考虑增大正则化系数 * setElasticNetParam正则化范式比(默认0),正则化有两种方式:L1(Lasso)和L2(Ridge),L1用于特征的稀疏化,L2用于防止过拟合 * setLabelCol设置标签列 * setFeaturesCol设置特征列 * setPredictionCol设置预测列 * setThreshold设置二分类阈值 */ //设置逻辑回归参数 val lr = new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setFamily() .setMaxIter(100).setRegParam(0.3).setElasticNetParam(0.8) //转换器,将预测的类别重新转成字符型 val labelConverter = new IndexToString() .setInputCol("prediction") .setOutputCol("predectionLabel") .setLabels(labelIndexer.labels) //建立工作流 val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter)) //生成模型 val model = lrPipeline.fit(trainingData) //预测 val result = model.transform(testData) //打印结果 result.show(200, false) //模型评估,预测准确性和错误率 val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction") val lrAccuracy: Double = evaluator.evaluate(result) println("Test Error = " + (1.0 - lrAccuracy)) spark.stop() } } case class Iris(features: Vector, label: String)
3、result
+-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+ |features |label |indexedLabel|indexedFeatures |rawPrediction |probability |prediction|predectionLabel| +-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+ |[4.9,2.4,3.3,1.0]|Iris-versicolor|0.0 |[4.9,3.0,3.3,0.0] |[1.0071037675553336,-1.0071037675553336] |[0.7324529695042751,0.2675470304957249] |0.0 |Iris-versicolor| |[5.0,2.0,3.5,1.0]|Iris-versicolor|0.0 |[5.0,0.0,3.5,0.0] |[0.938177922699384,-0.938177922699384] |[0.7187314594034615,0.2812685405965385] |0.0 |Iris-versicolor| |[5.6,2.5,3.9,1.1]|Iris-versicolor|0.0 |[5.6,4.0,3.9,1.0] |[0.7107814076350716,-0.7107814076350716] |[0.6705737993354417,0.3294262006645583] |0.0 |Iris-versicolor| |[5.6,2.9,3.6,1.3]|Iris-versicolor|0.0 |[5.6,8.0,3.6,3.0] |[0.6350805242141693,-0.6350805242141693] |[0.6536405613705153,0.3463594386294846] |0.0 |Iris-versicolor| |[5.8,2.7,4.1,1.0]|Iris-versicolor|0.0 |[5.8,6.0,4.1,0.0] |[0.7314003881315354,-0.7314003881315354] |[0.6751125028597408,0.32488749714025916]|0.0 |Iris-versicolor| |[6.1,2.8,4.7,1.2]|Iris-versicolor|0.0 |[6.1,7.0,4.7,2.0] |[0.34553320285886,-0.34553320285886] |[0.5855339747983552,0.41446602520164466]|0.0 |Iris-versicolor| |[6.2,2.2,4.5,1.5]|Iris-versicolor|0.0 |[6.2,1.0,4.5,5.0] |[0.14582457165756946,-0.14582457165756946] |[0.5363916772629104,0.46360832273708963]|0.0 |Iris-versicolor| |[6.4,2.9,4.3,1.3]|Iris-versicolor|0.0 |[6.4,8.0,4.3,3.0] |[0.39384006721834597,-0.39384006721834597] |[0.597206774507057,0.40279322549294305] |0.0 |Iris-versicolor| |[6.6,3.0,4.4,1.4]|Iris-versicolor|0.0 |[6.6,9.0,4.4,4.0] |[0.2698323194379575,-0.2698323194379575] |[0.5670517391689078,0.43294826083109217]|0.0 |Iris-versicolor| |[6.7,3.0,5.0,1.7]|Iris-versicolor|0.0 |[6.7,9.0,5.0,7.0] |[-0.20557969118713126,0.20557969118713126] |[0.44878532413929256,0.5512146758607075]|1.0 |Iris-virginica | |[6.7,3.1,4.4,1.4]|Iris-versicolor|0.0 |[6.7,10.0,4.4,4.0] |[0.2698323194379575,-0.2698323194379575] |[0.5670517391689078,0.43294826083109217]|0.0 |Iris-versicolor| |[7.0,3.2,4.7,1.4]|Iris-versicolor|0.0 |[7.0,11.0,4.7,4.0] |[0.16644355215403328,-0.16644355215403328] |[0.5415150896404186,0.4584849103595813] |0.0 |Iris-versicolor| |[4.9,2.5,4.5,1.7]|Iris-virginica |1.0 |[4.9,4.0,4.5,7.0] |[-0.033265079047257284,0.033265079047257284]|[0.49168449702809164,0.5083155029719083]|1.0 |Iris-virginica | |[5.4,3.0,4.5,1.5]|Iris-versicolor|0.0 |[5.4,9.0,4.5,5.0] |[0.14582457165756946,-0.14582457165756946] |[0.5363916772629104,0.46360832273708963]|0.0 |Iris-versicolor| |[5.6,2.8,4.9,2.0]|Iris-virginica |1.0 |[5.6,7.0,4.9,10.0] |[-0.43975124481639627,0.43975124481639627] |[0.39180024423019144,0.6081997557698086]|1.0 |Iris-virginica | |[5.6,3.0,4.1,1.3]|Iris-versicolor|0.0 |[5.6,9.0,4.1,3.0] |[0.4627659120742955,-0.4627659120742955] |[0.6136701219061476,0.38632987809385244]|0.0 |Iris-versicolor| |[5.8,2.7,3.9,1.2]|Iris-versicolor|0.0 |[5.8,6.0,3.9,2.0] |[0.6212365822826582,-0.6212365822826582] |[0.6504997376392441,0.34950026236075604]|0.0 |Iris-versicolor| |[5.8,2.7,5.1,1.9]|Iris-virginica |1.0 |[5.8,6.0,5.1,9.0] |[-0.419132264319932,0.419132264319932] |[0.3967244102962335,0.6032755897037665] |1.0 |Iris-virginica | |[5.9,3.0,5.1,1.8]|Iris-virginica |1.0 |[5.9,9.0,5.1,8.0] |[-0.32958743896751885,0.32958743896751885] |[0.4183410089972438,0.5816589910027563] |1.0 |Iris-virginica | |[6.0,2.9,4.5,1.5]|Iris-versicolor|0.0 |[6.0,8.0,4.5,5.0] |[0.14582457165756946,-0.14582457165756946] |[0.5363916772629104,0.46360832273708963]|0.0 |Iris-versicolor| |[6.1,3.0,4.6,1.4]|Iris-versicolor|0.0 |[6.1,9.0,4.6,4.0] |[0.20090647458200817,-0.20090647458200817] |[0.5500583546439539,0.4499416453560461] |0.0 |Iris-versicolor| |[6.2,3.4,5.4,2.3]|Iris-virginica |1.0 |[6.2,13.0,5.4,13.0]|[-0.8807003330135101,0.8807003330135101] |[0.29303267372325625,0.7069673262767437]|1.0 |Iris-virginica | |[6.7,3.1,4.7,1.5]|Iris-versicolor|0.0 |[6.7,10.0,4.7,5.0] |[0.07689872680162013,-0.07689872680162013] |[0.5192152136737482,0.48078478632625177]|0.0 |Iris-versicolor| |[6.7,3.3,5.7,2.5]|Iris-virginica |1.0 |[6.7,12.0,5.7,15.0]|[-1.163178751002261,1.163178751002261] |[0.23809016943453823,0.7619098305654617]|1.0 |Iris-virginica | |[6.8,3.0,5.5,2.1]|Iris-virginica |1.0 |[6.8,9.0,5.5,11.0] |[-0.7360736047366578,0.7360736047366578] |[0.32386333429517283,0.6761366657048272]|1.0 |Iris-virginica | |[6.9,3.1,5.4,2.1]|Iris-virginica |1.0 |[6.9,10.0,5.4,11.0]|[-0.7016106823086834,0.7016106823086834] |[0.33145521561995817,0.6685447843800418]|1.0 |Iris-virginica | |[7.2,3.6,6.1,2.5]|Iris-virginica |1.0 |[7.2,14.0,6.1,15.0]|[-1.3010304407141597,1.3010304407141597] |[0.21399164655179387,0.7860083534482062]|1.0 |Iris-virginica | |[7.7,2.8,6.7,2.0]|Iris-virginica |1.0 |[7.7,7.0,6.7,10.0] |[-1.0600838485199424,1.0600838485199424] |[0.2572934314622856,0.7427065685377143] |1.0 |Iris-virginica | |[7.7,3.0,6.1,2.3]|Iris-virginica |1.0 |[7.7,9.0,6.1,13.0] |[-1.1219407900093334,1.1219407900093334] |[0.24565146441425778,0.7543485355857422]|1.0 |Iris-virginica | |[7.9,3.8,6.4,2.0]|Iris-virginica |1.0 |[7.9,15.0,6.4,10.0]|[-0.9566950812360182,0.9566950812360182] |[0.2775403823663211,0.7224596176336789] |1.0 |Iris-virginica | +-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+ Test Error = 0.03314285714285714