Spark ML 之 推荐算法项目(上)

一、整体流程

二、具体召回流程

三、代码实现

 0、过滤不能推荐的商品(已下架/成人用品/烟酒等;下面的代码目前只按 is_sale 字段过滤)

package com.njbdqn.filter

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.sql.SparkSession

object BanGoodFilter {

  /**
   * Removes goods that must not be recommended and saves the remaining ones to HDFS.
   *
   * Reads the MySQL `goods` table, keeps only the rows matching `is_sale=0`
   * and writes them to `/myshops/dwd_good` (the path the KMeans and ALS data
   * handlers later read the candidate goods from).
   *
   * NOTE(review): only the `is_sale` flag is checked here. The original comment
   * describes is_sale=0 as "not sold / still available"; confirm against the
   * goods table schema.
   *
   * @param spark active Spark session
   */
  def ban(spark: SparkSession): Unit = {
    val recommendable = MYSQLConnection
      .readMySql(spark, "goods")
      .filter("is_sale=0")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_good", recommendable)
  }
}

 

 1、全局热点召回:把全局热卖商品 cross join 到每个用户上(保证每个用户都有可推荐的商品)

package com.njbdqn.call

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{desc, row_number, sum}

/**
 * Global (popularity-based) recall.
 *
 * Every customer receives the same list of best-selling goods, so that even
 * customers without any history get recommendation candidates.
 */
object GlobalHotCall {

  /**
   * Ranks goods by total units sold and writes two outputs:
   *  - HDFS `/myshops/dwd_hotsell`: every customer cross-joined with the top `topN`
   *    goods (columns cust_id, good_id, sellnum, rank);
   *  - MySQL table `hotsell`: the top `guestTopN` goods, meant for anonymous visitors.
   *
   * @param spark     active Spark session
   * @param topN      number of best sellers recalled for every customer (default 100)
   * @param guestTopN number of best sellers stored for visitors (default 10)
   */
  def hotSell(spark: SparkSession, topN: Int = 100, guestTopN: Int = 10): Unit = {
    import spark.implicits._
    // (good_id, sellnum, rank): rank 1 = most units sold; row_number gives unique ranks.
    val ranked = MYSQLConnection.readMySql(spark, "orderItems")
      .groupBy("good_id")
      .agg(sum("buy_num").alias("sellnum"))
      .withColumn("rank", row_number().over(Window.orderBy(desc("sellnum"))))
    // Select by rank rather than with limit(): limit() on a relation that is not
    // itself sorted has no ordering guarantee, so it may not return the best sellers.
    // Cached because it feeds both the HDFS and the MySQL output.
    val topGoods = ranked.filter($"rank" <= topN).cache()

    // Every customer id paired with every top good.
    val customers = MYSQLConnection.readMySql(spark, "customs").select($"cust_id")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_hotsell", customers.crossJoin(topGoods))

    // Visitors (no cust_id): only the very top goods go to the MySQL `hotsell` table.
    MYSQLConnection.writeTable(spark, topGoods.filter($"rank" <= guestTopN), "hotsell")

    topGoods.unpersist()
  }
}

 2、分组召回

详细见 https://www.cnblogs.com/sabertobih/p/13824739.html

数据处理,归一化:

package com.njbdqn.datahandler

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.feature.{MinMaxScaler, StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, current_date, datediff, desc, min, row_number, sum, udf}
import org.apache.spark.sql.types.DoubleType

/**
 * Feature engineering for customer clustering (K-Means).
 *
 * Builds one row per active customer from profile attributes and purchase
 * behaviour, bucketizes several raw values with UDFs, and produces a
 * min-max scaled feature vector.
 */
object KMeansHandler {
  /** Buckets `membership_level`: < 100 -> 1, < 500 -> 2, < 1000 -> 3, otherwise 4. */
  val func_membership = udf {
    (score: Int) => {
      score match {
        case i if i < 100 => 1
        case i if i < 500 => 2
        case i if i < 1000 => 3
        case _ => 4
      }
    }
  }

  /**
   * Age in whole years, computed from the birth date embedded in an ID number
   * (characters [6,10) = year, [10,12) = month, [12,14) = day) relative to `now`,
   * which is split on "-" as year-month-day. The age is one less while this
   * year's birthday has not been reached yet.
   *
   * NOTE(review): assumes every idno is non-null and has at least 14 characters;
   * a null or shorter value makes the UDF throw. Confirm against the data.
   */
  val func_bir = udf {
    (idno: String, now: String) => {
      val year = idno.substring(6, 10).toInt
      val month = idno.substring(10, 12).toInt
      val day = idno.substring(12, 14).toInt

      val dts = now.split("-")
      val nowYear = dts(0).toInt
      val nowMonth = dts(1).toInt
      val nowDay = dts(2).toInt

      if (nowMonth > month) {
        nowYear - year
      } else if (nowMonth < month) {
        nowYear - 1 - year
      } else {
        if (nowDay >= day) {
          nowYear - year
        } else {
          nowYear - 1 - year
        }
      }
    }
  }

  /** Buckets an age in years: <10, <18, <23, <35, <50, <70 -> 1..6, otherwise 7. */
  val func_age = udf {
    (num: Int) => {
      num match {
        case n if n < 10 => 1
        case n if n < 18 => 2
        case n if n < 23 => 3
        case n if n < 35 => 4
        case n if n < 50 => 5
        case n if n < 70 => 6
        case _ => 7
      }
    }
  }

  /** Buckets `biz_point`: < 100 -> 1, < 500 -> 2, otherwise 3. */
  val func_userscore = udf {
    (sc: Int) => {
      sc match {
        case s if s < 100 => 1
        case s if s < 500 => 2
        case _ => 3
      }
    }
  }

  /** Buckets `login_count`: < 500 -> 1, otherwise 2. */
  val func_logincount = udf {
    (sc: Int) => {
      sc match {
        case s if s < 500 => 1
        case _ => 2
      }
    }
  }

  /** Numeric feature columns produced by `user_act_info`, in the order they are packed into the vector. */
  private val FeatureCols = Array(
    "province_id", "city_id", "district_id", "sex", "marital_status", "education_id",
    "vocation", "post", "compId", "mslevel", "reg_date", "lasttime", "age",
    "user_score", "logincount", "buycount", "pay")

  /**
   * Joins the natural attributes of active customers (`active != 0`) with their behaviour.
   *
   * Output columns: cust_id, province_id, city_id, district_id, sex, marital_status,
   * education_id, vocation, post, compId, mslevel, reg_date, lasttime, age,
   * user_score, logincount, buycount, pay.
   *
   * @param spark active Spark session
   * @return one row per active customer; numeric nulls replaced by 0
   */
  def user_act_info(spark: SparkSession): DataFrame = {
    val featureDataTable = MYSQLConnection.readMySql(spark, "customs").filter("active!=0")
      .select("cust_id", "company", "province_id", "city_id", "district_id"
        , "membership_level", "create_at", "last_login_time", "idno", "biz_point", "sex", "marital_status", "education_id"
        , "login_count", "vocation", "post")
    // Goods (candidate pool written by the filter job): price is needed to compute spending.
    val goodTable = HDFSConnection.readDataToHDFS(spark, "/myshops/dwd_good").select("good_id", "price")
    // Orders.
    val orderTable = MYSQLConnection.readMySql(spark, "orders").select("ord_id", "cust_id")
    // Order lines.
    val orddetailTable = MYSQLConnection.readMySql(spark, "orderItems").select("ord_id", "good_id", "buy_num")
    // The company name is a string; StringIndexer turns it into a numeric index.
    val compIndex = new StringIndexer().setInputCol("company").setOutputCol("compId")
    import spark.implicits._
    // Number of orders per customer.
    val tmp_bc = orderTable.groupBy("cust_id").agg(count($"ord_id").as("buycount"))
    // Total amount spent per customer: sum(buy_num * price).
    val tmp_pay = orderTable.join(orddetailTable, Seq("ord_id"), "inner").join(goodTable, Seq("good_id"), "inner").groupBy("cust_id").
      agg(sum($"buy_num" * $"price").as("pay"))

    compIndex.fit(featureDataTable).transform(featureDataTable)
      .withColumn("mslevel", func_membership($"membership_level"))
      // Registration / last login expressed as days since the earliest value over all customers.
      .withColumn("min_reg_date", min($"create_at").over())
      .withColumn("reg_date", datediff($"create_at", $"min_reg_date"))
      .withColumn("min_login_time", min("last_login_time").over())
      .withColumn("lasttime", datediff($"last_login_time", $"min_login_time"))
      .withColumn("age", func_age(func_bir($"idno", current_date())))
      .withColumn("user_score", func_userscore($"biz_point"))
      .withColumn("logincount", func_logincount($"login_count"))
      // Left joins: customers who never ordered/spent have no row on the right side
      // but must be kept; their missing values become 0 below.
      .join(tmp_bc, Seq("cust_id"), "left").join(tmp_pay, Seq("cust_id"), "left")
      .na.fill(0)
      .drop("company", "membership_level", "create_at", "min_reg_date"
        , "last_login_time", "min_login_time", "idno", "biz_point", "login_count")
  }

  /**
   * Builds the clustering input: (cust_id, feature) where `feature` is the
   * min-max scaled vector of all feature columns.
   *
   * cust_id keeps its original type. Casting the key to Double, as a blanket cast
   * of every column would, turns non-numeric ids into null and breaks the join
   * back to orders. Feature values that are null after the cast to Double become
   * 0.0, because VectorAssembler rejects null values.
   *
   * @param spark active Spark session
   * @return DataFrame (cust_id, feature)
   */
  def user_group(spark: SparkSession): DataFrame = {
    val df = user_act_info(spark)
    val numeric = df
      .select(col("cust_id") +: FeatureCols.map(c => col(c).cast(DoubleType)): _*)
      .na.fill(0.0, FeatureCols)
    val assembled = new VectorAssembler()
      .setInputCols(FeatureCols)
      .setOutputCol("orign_feature")
      .transform(numeric)
      .select("cust_id", "orign_feature")
    // Scale every feature to [0, 1] so no single column dominates the K-Means distance.
    new MinMaxScaler()
      .setInputCol("orign_feature")
      .setOutputCol("feature")
      .fit(assembled)
      .transform(assembled)
      .select("cust_id", "feature")
  }
}

KMeans 计算分组召回:

package com.njbdqn.call

import com.njbdqn.datahandler.KMeansHandler
import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._


/**
 * Group-based recall: clusters customers with K-Means and recommends to every
 * customer the goods most frequently ordered inside their cluster.
 */
object GroupCall {

  /**
   * Clusters customers and writes (cust_id, groups, good_id, group_buy_count, rank)
   * to HDFS `/myshops/dwd_kMeans`. Each customer gets the top `topN` goods of their
   * cluster, ranked by the number of order lines containing the good.
   *
   * @param spark active Spark session
   * @param k     number of K-Means clusters (default 40)
   * @param topN  goods kept per cluster (default 30)
   */
  def calc_groups(spark: SparkSession, k: Int = 40, topN: Int = 30): Unit = {
    // Choosing k (elbow method), kept for reference and not executed:
    // fit K-Means for k = 2..40, record each model's cost and plot it to find a stable k.
    // No transform is needed there: the goal is only the cost (SSD) for each k,
    // not the (cust_id, feature, prediction) table.
    //    val disList:ListBuffer[Double]=ListBuffer[Double]()
    //    for (i<-2 to 40){
    //      val kms=new KMeans().setFeaturesCol("feature").setK(i)
    //      val model=kms.fit(resdf)
    //      disList.append(model.computeCost(resdf))
    //    }
    //    val chart=new LineGraph("app","Kmeans质心和距离",disList)
    //    chart.pack()
    //    RefineryUtilities.centerFrameOnScreen(chart)
    //    chart.setVisible(true)

    import spark.implicits._
    val orderTable = MYSQLConnection.readMySql(spark, "orders").select("ord_id", "cust_id")
    val orddetailTable = MYSQLConnection.readMySql(spark, "orderItems").select("ord_id", "good_id", "buy_num")
    // Cached: both fit and transform scan it. Without the cache, transform would re-run
    // the whole feature pipeline, including the MySQL/HDFS reads.
    val resdf = KMeansHandler.user_group(spark).cache()

    // Cluster of every customer: (cust_id, groups).
    val user_group_tab = new KMeans().setFeaturesCol("feature").setK(k)
      .fit(resdf)
      .transform(resdf)
      .drop("feature")
      .withColumnRenamed("prediction", "groups")
      .cache()

    // Per cluster, count order lines per good and keep the topN most frequent goods.
    val wnd = Window.partitionBy("groups").orderBy(desc("group_buy_count"))
    val groups_goods = user_group_tab
      .join(orderTable, Seq("cust_id"), "inner")
      .join(orddetailTable, Seq("ord_id"), "inner")
      .na.fill(0)
      .groupBy("groups", "good_id")
      .agg(count("ord_id").as("group_buy_count"))
      .withColumn("rank", row_number().over(wnd))
      .filter($"rank" <= topN)

    // Every customer receives the top goods of the cluster they belong to.
    val df5 = user_group_tab.join(groups_goods, Seq("groups"), "inner")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_kMeans", df5)

    user_group_tab.unpersist()
    resdf.unpersist()
  }
}

 3、ALS协同过滤召回

ALS数据预处理:User-Item 稀疏矩阵中的 score(行为评分)需要量化为数字,且每一列都必须是数值类型;最后把稀疏表转换为 Rating 集合。

package com.njbdqn.datahandler

import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.mllib.recommendation.Rating
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{row_number, sum, udf}

/**
 * Prepares the input of the ALS collaborative-filtering recall.
 *
 * Behaviour logs become (user, item, score) triples. ALS needs integer ids, so every
 * good_id / cust_id is first mapped to a dense integer (gid / uid) with row_number
 * over the sorted ids; this guards against ids that contain non-digit characters.
 */
object ALSDataHandler {

  /**
   * Numbers every candidate good read from `/myshops/dwd_good`.
   *
   * NOTE(review): this is called more than once (here and by the ALS call job), so
   * the numbering is only reproducible if good_id is unique; ties would make the
   * row_number order nondeterministic.
   *
   * @return cached (good_id, gid), gid = 1, 2, ... in ascending good_id order
   */
  def goods_to_num(spark: SparkSession): DataFrame = {
    import spark.implicits._
    val wnd1 = Window.orderBy("good_id")
    HDFSConnection.readDataToHDFS(spark, "/myshops/dwd_good").select("good_id", "price")
      .select($"good_id", row_number().over(wnd1).alias("gid")).cache()
  }

  /**
   * Numbers every customer of the MySQL `customs` table.
   *
   * @return cached (cust_id, uid), uid = 1, 2, ... in ascending cust_id order
   */
  def user_to_num(spark: SparkSession): DataFrame = {
    import spark.implicits._
    val wnd2 = Window.orderBy("cust_id")
    MYSQLConnection.readMySql(spark, "customs")
      .select($"cust_id", row_number().over(wnd2).alias("uid")).cache()
  }

  /** Score of one logged action: BROWSE -> 1, COLLECT -> 2, BUYCAR -> 3, anything else -> 8. */
  val actToNum = udf {
    (str: String) => {
      str match {
        case "BROWSE" => 1
        case "COLLECT" => 2
        case "BUYCAR" => 3
        case _ => 8
      }
    }
  }

  /** One space-separated log record: `act act_time cust_id good_id browse`. */
  case class UserAction(act: String, act_time: String, cust_id: String, good_id: String, browse: String)

  /**
   * Builds the ALS ratings: for each (customer, good) pair the sum of the action
   * scores found in the logs, expressed with the numeric uid / gid.
   *
   * Lines with fewer than 5 space-separated fields (blank or truncated lines) are
   * skipped instead of failing the whole job. Pairs whose good or customer is
   * unknown to the numbering tables are dropped by the inner joins.
   *
   * @param spark   active Spark session
   * @param logPath location of the behaviour logs (defaults to the original local path)
   * @return one Rating(uid, gid, score) per (customer, good) pair
   */
  def als_data(spark: SparkSession, logPath: String = "file:///D:/logs/virtualLogs/*.log"): RDD[Rating] = {
    val goodstab: DataFrame = goods_to_num(spark)
    val custstab: DataFrame = user_to_num(spark)
    import spark.implicits._
    // Score of each customer for every good they interacted with.
    val df = spark.sparkContext.textFile(logPath)
      .map(_.split(" "))
      .filter(_.length >= 5)
      .map(arr => UserAction(arr(0), arr(1), arr(2), arr(3), arr(4)))
      .toDF().drop("act_time", "browse")
      .select($"cust_id", $"good_id", actToNum($"act").alias("score"))
      .groupBy("cust_id", "good_id")
      .agg(sum($"score").alias("score"))
    // Swap the original ids for their numeric counterparts: keep only (gid, uid, score).
    val df2 = df.join(goodstab, Seq("good_id"), "inner")
      .join(custstab, Seq("cust_id"), "inner")
      .select("gid", "uid", "score")
    // Sparse table -> Rating objects. getAs needs an explicit type argument: without
    // one it is inferred as getAs[Nothing], and the cast the compiler inserts can fail
    // with a ClassCastException.
    df2.rdd.map { row =>
      Rating(
        row.getAs[Any]("uid").toString.toInt,
        row.getAs[Any]("gid").toString.toInt,
        row.getAs[Any]("score").toString.toFloat
      )
    }
  }
}

ALS训练:训练完成后需要把数字编号(uid/gid)还原为原始的 cust_id/good_id(数字 => 原始编号)。

package com.njbdqn.call

import com.njbdqn.datahandler.ALSDataHandler
import com.njbdqn.datahandler.ALSDataHandler.{goods_to_num, user_to_num}
import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * ALS collaborative-filtering recall.
 */
object ALSCall {

  /**
   * Trains an explicit-feedback ALS model on all ratings and writes, for every
   * customer, the top `numRecommendations` goods to HDFS `/myshops/dwd_ALS_Iter20`
   * as (cust_id, good_id, rank).
   *
   * The `rank` column holds the model's predicted rating, not a position.
   *
   * @param spark              active Spark session
   * @param numRecommendations goods recommended per customer (default 30)
   */
  def als_call(spark: SparkSession, numRecommendations: Int = 30): Unit = {
    // Numbering tables used to translate uid / gid back to the original ids.
    val goodstab: DataFrame = goods_to_num(spark)
    val custstab: DataFrame = user_to_num(spark)
    val alldata: RDD[Rating] = ALSDataHandler.als_data(spark).cache()
    // The model is trained on ALL ratings. A train/test split
    // (alldata.randomSplit(Array(0.8, 0.2))) is not used, so no evaluation happens here.
    val model = new ALS()
      .setCheckpointInterval(2)
      .setRank(10)
      .setIterations(20)
      .setLambda(0.01)
      .setImplicitPrefs(false)
      .run(alldata)

    import spark.implicits._
    // (uid, gid, predicted rating) -> (cust_id, good_id, rank)
    val recs = model.recommendProductsForUsers(numRecommendations)
      .flatMap { case (user, ratings) => ratings.map(r => (user, r.product, r.rating)) }
      .toDF("uid", "gid", "rank")
      .join(goodstab, Seq("gid"), "inner")
      .join(custstab, Seq("uid"), "inner")
      .select($"cust_id", $"good_id", $"rank")
    HDFSConnection.writeDataToHDFS("/myshops/dwd_ALS_Iter20", recs)

    alldata.unpersist()
  }
}

 

 

posted @ 2020-10-25 16:10  PEAR2020  阅读(623)  评论(0编辑  收藏  举报