收藏:①极市开发DeepLearning ②Git使用

推荐系统-双塔召回随机负采样

// 二分查找
  /**
   * Maps a draw `target` onto the item whose cumulative-weight interval contains it.
   *
   * `trainItems` must be sorted ascending by its second field (the running total of the
   * sampling weights). The result is the id of the first entry whose cumulative value is
   * `>= target`; when every value is below `target` (e.g. the total sums to slightly under
   * 1.0 because of floating-point rounding) the last entry is returned.
   *
   * @param trainItems (item_id, cumulative weight) pairs sorted by cumulative weight
   * @param target     value to locate, typically a uniform draw in [0, 1)
   * @return the selected item id, or "" when `trainItems` is empty
   */
  def fetchBinarySearch(trainItems: Array[(String, Double)], target: Double): String = {
    // Narrows [lo, hi] until one index remains; `hi` is always a valid candidate.
    @scala.annotation.tailrec
    def lowerBound(lo: Int, hi: Int): Int =
      if (lo >= hi) lo
      else {
        val pivot = (lo + hi) / 2
        if (trainItems(pivot)._2 < target) lowerBound(pivot + 1, hi)
        else lowerBound(lo, pivot)
      }

    if (trainItems.isEmpty) "" else trainItems(lowerBound(0, trainItems.length - 1))._1
  }

  // 获取采样负样本用户集
  /**
   * Builds a Spark UDF that expands a user's positive items into labelled
   * (item_id, label) pairs: every positive gets label 1, followed by `negNum`
   * weighted-sampled negatives per positive with label 0.
   *
   * The input string is split on `,`; for each entry only the part before the first `:`
   * is used as the item id. Empty ids (empty input, stray commas) are ignored.
   *
   * Negatives are drawn by inverse-CDF sampling: a uniform draw from `nextDouble()` is
   * mapped through `fetchBinarySearch` onto `trainItems`, whose second field is the
   * cumulative probability. A draw that hits one of the user's positives is rejected and
   * retried; negatives themselves may repeat (sampling with replacement). To avoid hanging
   * the task when no eligible negative exists, at most 100 draws per requested negative are
   * made, so a row can come back with fewer negatives than requested.
   *
   * NOTE(review): the UDF is random but not marked `asNondeterministic()` (Spark 2.3+);
   * consider doing so if the Spark version in use allows it.
   *
   * @param trainItems     (item_id, cumulative probability) pairs sorted by the probability
   * @param trainItemsSize number of train items; unused by the weighted sampler (kept for
   *                       callers and for the uniform-sampling variant)
   * @param negNum         number of negatives to draw per positive item
   * @return a UDF mapping the item-list string to a sequence of (item_id, label);
   *         a null input yields an empty sequence
   */
  def fetchFullSampleItemsUdf(trainItems: Array[(String, Double)], trainItemsSize: Int, negNum: Int): UserDefinedFunction = {
    // Upper bound on draws per requested negative. Rejection sampling never terminates
    // if every item with non-zero probability is a positive of the user.
    val maxAttemptsPerNegative = 100L

    udf((app_info: String) => {
      if (app_info == null) {
        Seq.empty[(String, Int)]
      } else {
        val positives = app_info.split(",").map(_.split(":")(0)).filter(_.nonEmpty)
        val positiveSet = positives.toSet
        // With no candidate items fetchBinarySearch can only return "", which is not an item.
        val wanted = if (trainItems.isEmpty) 0 else positives.length * negNum
        val maxAttempts = wanted * maxAttemptsPerNegative
        // One generator per row instead of one per draw.
        val rng = new Random
        val negatives = scala.collection.mutable.ArrayBuffer.empty[String]
        var attempts = 0L
        while (negatives.length < wanted && attempts < maxAttempts) {
          attempts += 1
          val negItem = fetchBinarySearch(trainItems, rng.nextDouble())
          if (!positiveSet.contains(negItem)) negatives += negItem
        }
        // Positives first (label 1), then negatives (label 0), as before.
        positives.map(item => (item, 1)).toVector ++ negatives.map(item => (item, 0))
      }
    })
  }




// 样本数据拼接
  /**
   * Assembles the labelled training samples for one `day` / `part`.
   *
   * Steps:
   *  1. load the target rows (`fetchTargetData`) and collect their distinct user ids;
   *  2. build a cumulative sampling distribution over the items of
   *     `fetchItemSampleData(spark, day)`, weighting each (day, appid) by
   *     `count(user_id)^0.75` normalised by the day's total;
   *  3. left-join user features, expand each row's `app_info` into positive (target 1) and
   *     sampled negative (target 0) items via `fetchFullSampleItemsUdf` and `explode`, and
   *     left-join (broadcast) item features on `item_id`.
   *
   * @param spark  the active SparkSession
   * @param day    day passed to every loader
   * @param part   part passed to `fetchTargetData`
   * @param negNum negatives to sample per positive item
   * @return the exploded sample rows with `item_id` and `target` columns plus user/item features
   */
  def fetchSampleData(spark: SparkSession, day: String, part: String, negNum: Int): DataFrame = {
    val targetData = fetchTargetData(spark, day, part)

    // Distinct user ids of this part, collected to the driver for fetchUserFeatures;
    // the "1" value is only a placeholder.
    val userMap = targetData.select("user_id").dropDuplicates("user_id").rdd
      .map(row => (row.getAs[String]("user_id"), "1"))
      .collect()
      .toMap

    val win = Window.partitionBy("day")
    val win2 = Window.partitionBy("day").orderBy("pv")
    val win3 = Window.partitionBy("day").orderBy("rank")
    // (appid, cumulative probability) sorted ascending, as fetchBinarySearch expects.
    // NOTE(review): rows of every day returned by fetchItemSampleData end up in one array,
    // so this assumes it returns a single day - TODO confirm.
    val trainItems = fetchItemSampleData(spark, day)
      .groupBy("day", "appid").agg(expr("power(count(user_id), 0.75) as pv"))
      .withColumn("pv_sum", sum("pv").over(win))
      .withColumn("fw", col("pv") / col("pv_sum"))
      // The running sum is ordered by a unique row_number rather than by pv: with an
      // ORDER BY the default frame is RANGE UNBOUNDED PRECEDING .. CURRENT ROW, so rows
      // with equal pv would all receive the same cumulative value.
      .withColumn("rank", row_number().over(win2))
      .withColumn("fp", sum("fw").over(win3))
      .select("appid", "fp") // only these two columns are needed on the driver
      .rdd
      .map(row => (row.getAs[String]("appid"), row.getAs[Double]("fp")))
      .collect()
      .sortBy(_._2)
    val trainItemsSize = trainItems.length

    val userFeatures = fetchUserFeatures(spark, day, userMap)
    val itemFeatures = fetchItemFeatures(spark, day)
    targetData.join(userFeatures, Seq("user_id"), "left")
      .withColumn("fullSampleItems", fetchFullSampleItemsUdf(trainItems, trainItemsSize, negNum)(col("app_info")))
      .withColumn("fullSampleItems", explode(col("fullSampleItems")))
      .withColumn("item_id", col("fullSampleItems").getField("_1"))
      .withColumn("target", col("fullSampleItems").getField("_2"))
      .join(broadcast(itemFeatures), Seq("item_id"), "left")
  }

 

posted @ 2024-01-17 17:38  WSX_1994  阅读(6)  评论(0编辑  收藏  举报