Spark ML: Ranking with LR (Logistic Regression)

I. Theory

https://www.jianshu.com/p/114100d0517f

https://www.imooc.com/article/46843
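
The two links above cover the theory in depth. The one fact the code below relies on is that logistic regression maps a linear score to a positive-class probability through the sigmoid function:

$$P(y = 1 \mid \mathbf{x}) = \sigma(\mathbf{w}^{\top}\mathbf{x} + b) = \frac{1}{1 + e^{-(\mathbf{w}^{\top}\mathbf{x} + b)}}$$

In Spark ML this positive-class probability is the second component of the probability vector column, and it is exactly the score the ranking below sorts by.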

II. Code

1. Prepare the data

2. Split the data into train and test sets: fit a model on the train data, then apply it (transform) to the test data,

and check whether the predictions match the labels closely enough (a minimal sketch of this check appears after the code below).

3. Rank by the predicted score

package com.njbdqn

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.{DenseVector, Vectors}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Row, SparkSession}

/**
 * Ranking with LR
 */
object LRtest {
  // Parses the string form of the probability vector "[p0,p1]" and
  // returns p1, the probability of the positive class
  val positive = udf { (vc: String) =>
    vc.replaceAll("\\[|\\]", "").split(",")(1).toDouble
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("app").master("local[*]").getOrCreate()
    val data = spark.createDataFrame(Seq(
      ("1", "2", 1.0, Vectors.dense(0.0, 1.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, -1.3, 1.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 1.0, Vectors.dense(2.0, 1.3, 1.1)),
      ("1", "2", 0.0, Vectors.dense(-2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(2.0, -1.3, 1.1)),
      ("1", "2", 1.0, Vectors.dense(2.0, 1.0, -1.1)),
      ("1", "2", 1.0, Vectors.dense(1.0, 2.1, 0.1)),
      ("1", "2", 0.0, Vectors.dense(-2.0, 1.3, 1.1)),
      ("1", "2", 1.0, Vectors.dense(0.0, 1.2, -0.4))
    )).toDF("user", "goods", "label", "features")
    //.show(false)
    val Array(train, test) = data.randomSplit(Array(0.7, 0.3))
    // Set the model's hyperparameters
    val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
    // Train the model
    val model = lr.fit(train)
    // Save the model to HDFS
    //  model.save("hdfs://192.168.56.111:9000/LRmodel")
    // Load the model back from HDFS
    val model2 = LogisticRegressionModel.load("hdfs://192.168.56.111:9000/LRmodel")
    // Check the model's accuracy
    //    val preRes = model.transform(test)
    //    preRes.show(false)
    val res = model2.transform(data)
    import spark.implicits._
    // Method 1: brute force, not recommended.
    // probability is a vector [p0, p1]; p1 is the degree of interest,
    // and predict is 1 when it exceeds 0.5
    res.withColumn("pro", $"probability".cast("String"))
      .select($"user", $"goods", positive($"pro").alias("score"))
      .orderBy(desc("score")).show(false)
    // Method 2: recommended, pattern matching.
    // user and goods are String columns; probability is a DenseVector
    // Window spec (assumed definition: rank each user's goods by descending score)
    val wnd = Window.partitionBy("user").orderBy(desc("score"))
    res.select("user", "goods", "probability")
      .rdd.map { case Row(uid: String, gid: String, prob: DenseVector) => (uid, gid, prob(1)) }
      .toDF("user", "goods", "score")
      .select($"user", $"goods", row_number().over(wnd).alias("rank"))
      .show(false)
    spark.stop()
  }
}
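
Step 2 calls for checking how closely the predictions match the labels, which the code above leaves commented out. Here is a minimal sketch of one way to do that check with Spark ML's BinaryClassificationEvaluator, reusing the model and test from the code above (the evaluator's default metric is areaUnderROC):

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val preRes = model.transform(test)
// label and rawPrediction are the evaluator's default input columns
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("rawPrediction")
println(s"test AUC = ${evaluator.evaluate(preRes)}") // areaUnderROC by default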

Result:

+----+-----+-------------------+
|user|goods|score              |
+----+-----+-------------------+
|1   |2    |0.9473385564891683 |
|1   |2    |0.9473385564891683 |
|1   |2    |0.9473385564891683 |
|1   |2    |0.9473385564891683 |
|1   |2    |0.9202855138287962 |
|1   |2    |0.5337766179253915 |
|1   |2    |0.5337766179253915 |
|1   |2    |0.5337766179253915 |
|1   |2    |0.5081492680443979 |
|1   |2    |0.5014483932183084 |
|1   |2    |0.4713578993198038 |
|1   |2    |0.09069927610736443|
|1   |2    |0.03241657419240436|
+----+-----+-------------------+
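
Note that several rows share an identical score: the toy dataset repeats a few feature vectors (for example, Vectors.dense(1.0, 2.1, 0.1) appears four times), so the model necessarily assigns them the same probability, which is why 0.9473385564891683 shows up four times.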
