Spark ML 之 LR逻辑回归实现排序
一、理论
https://www.jianshu.com/p/114100d0517f
https://www.imooc.com/article/46843
二、代码
1、准备数据
2、数据分成 train 和 test 进行测试:用 train 数据训练(fit)出的 model,带入(transform)test 数据,
验证 label 与 prediction 值是否足够精确
3、排序
package com.njbdqn

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.{DenseVector, Vectors}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Row, SparkSession}

/**
 * Ranking demo: score (user, goods) pairs with a LogisticRegression model and
 * order them by the predicted probability of the positive class (label = 1.0).
 */
object LRtest {

  /**
   * Extracts the positive-class probability from the *string form* of the
   * probability vector, e.g. "[0.05,0.95]" -> 0.95.
   * Used only by "method 1" below; prefer the typed extraction in "method 2".
   */
  val positive = udf { (vc: String) =>
    vc.replaceAll("\\[|\\]", "").split(",")(1).toDouble
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("app").master("local[*]").getOrCreate()
    import spark.implicits._

    // Toy data set: (user, goods, label, features).
    // NOTE: user/goods are String columns — the Row pattern match below must
    // match them as String, not Double.
    val data = spark.createDataFrame(Seq(
      ("1", "2", 1.0, Vectors.dense(0.0, 1.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, -1.3, 1.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 1.0, Vectors.dense(2.0, 1.3, 1.1)),
      ("1", "2", 0.0, Vectors.dense(-2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, -1.3, 1.1)),
      ("1", "2", 1.0, Vectors.dense(2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(-2.0, 1.3, 1.1)),
      ("1", "2", 1.0, Vectors.dense(0.0, 1.2, -0.4))
    )).toDF("user", "goods", "label", "features")

    // Split into train/test: fit on train, validate on test.
    val Array(train, test) = data.randomSplit(Array(0.7, 0.3))

    // Hyper-parameters: at most 10 iterations, regularization 0.01.
    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)

    // Train the model.
    val model = lr.fit(train)

    // Persist the model to HDFS (run once before relying on load below).
    // model.save("hdfs://192.168.56.111:9000/LRmodel")

    // NOTE(review): this load fails on a fresh run unless the save above was
    // executed first — confirm the model exists at this path.
    val model2 = LogisticRegressionModel.load("hdfs://192.168.56.111:9000/LRmodel")

    // Accuracy sanity check on the held-out split:
    // val preRes = model.transform(test)
    // preRes.show(false)

    val res = model2.transform(data)

    // Method 1 (not recommended): cast the probability vector to String and
    // parse the positive-class probability with a udf.
    // probability = [P(label=0), P(label=1)]; prediction is 1 when P(1) > 0.5.
    res.withColumn("pro", $"probability".cast("String"))
      .select($"user", $"goods", positive($"pro").alias("score"))
      .orderBy(desc("score"))
      .show(false)

    // Method 2 (recommended): pattern-match the Row and read the positive-class
    // probability directly from the DenseVector, then rank per user.
    // FIX: user/goods matched as String (they are String columns), and the
    // window spec `wnd` is now defined before use.
    val wnd = Window.partitionBy($"user").orderBy(desc("score"))
    res.select("user", "goods", "probability")
      .rdd
      .map { case Row(uid: String, gid: String, prob: DenseVector) => (uid, gid, prob(1)) }
      .toDF("user", "goods", "score")
      .select($"user", $"goods", $"score", row_number().over(wnd).alias("rank"))
      .show(false)

    spark.stop()
  }
}
结果:
+----+-----+-------------------+
|user|goods|score |
+----+-----+-------------------+
|1 |2 |0.9473385564891683 |
|1 |2 |0.9473385564891683 |
|1 |2 |0.9473385564891683 |
|1 |2 |0.9473385564891683 |
|1 |2 |0.9202855138287962 |
|1 |2 |0.5337766179253915 |
|1 |2 |0.5337766179253915 |
|1 |2 |0.5337766179253915 |
|1 |2 |0.5081492680443979 |
|1 |2 |0.5014483932183084 |
|1 |2 |0.4713578993198038 |
|1 |2 |0.09069927610736443|
|1 |2 |0.03241657419240436|
|1 |2 |0.03241657419240436|
+----+-----+-------------------+