minhash pyspark 源码分析——hash join table是关键

从下面分析可以看出,是先做了hash计算,然后使用hash join table来讲hash值相等的数据合并在一起。然后再使用udf计算距离,最后再filter出满足阈值的数据:

参考:https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
  /**
   * Join two datasets to approximately find all pairs of rows whose distance are smaller than
   * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the
   * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
   * data when necessary.
   *
   * @param datasetA One of the datasets to join.
   * @param datasetB Another dataset to join.
   * @param threshold The threshold for the distance of row pairs.
   * @param distCol Output column for storing the distance between each pair of rows.
   * @return A joined dataset containing pairs of rows. The original rows are in columns
   *         "datasetA" and "datasetB", and a column "distCol" is added to show the distance
   *         between each pair.
   */
  def approxSimilarityJoin(
      datasetA: Dataset[_],
      datasetB: Dataset[_],
      threshold: Double,
      distCol: String): Dataset[_] = {

    val leftColName = "datasetA"
    val rightColName = "datasetB"
    val explodeCols = Seq("entry", "hashValue")
    val explodedA = processDataset(datasetA, leftColName, explodeCols)

    // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity.
    // TODO: Remove recreateCol logic once SPARK-17154 is resolved.
    val explodedB = if (datasetA != datasetB) {
      processDataset(datasetB, rightColName, explodeCols)
    } else {
      val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}")
      processDataset(recreatedB, rightColName, explodeCols)
    }

    // Do a hash join on where the exploded hash values are equal.
    val joinedDataset = explodedA.join(explodedB, explodeCols)
      .drop(explodeCols: _*).distinct()

    // Add a new column to store the distance of the two rows.
    val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
    val joinedDatasetWithDist = joinedDataset.select(col("*"),
      distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
    )

    // Filter the joined datasets where the distance are smaller than the threshold.
    joinedDatasetWithDist.filter(col(distCol) < threshold)
  }

 

补充:

sql join 算法 时间复杂度

参考

stackoverflow

笔记

sql语句如下:

SELECT  T1.name, T2.date
FROM    T1, T2
WHERE   T1.id=T2.id
        AND T1.color='red'
        AND T2.type='CAR'

 

假设T1有m行,T2有n行,那么,普通情况下,应该要遍历T1的每一行的id(m),然后在遍历T2(n)中找出T2.id = T1.id的行进行join。时间复杂度应该是O(m*n)

如果没有索引的话,engine会选择hash join或者merge join进行优化。

hash join是这样的:

  1. 选择被哈希的表,通常是小一点的表。让我们愉快地假定是T1更小吧。
  2. T1所有的记录都被遍历。如果记录符合color=’red’,这条记录就会进去哈希表,以id为key,以name为value。
  3. T2所有的记录被遍历。如果记录符合type=’CAR’,使用这条记录的id去搜索哈希表,所有命中的记录的name的值,都被返回,还带上了当前记录的date的值,这样就可以把两者join起来了。

时间复杂度O(n+m),实现hash表是O(n),hash表查找是O(m),直接将其相加。

merge join是这样的:

1.复制T1(id, name),根据id排序。
2.复制T2(id, date),根据id排序。
3.两个指针指向两个表的最小值。

    >1 2<
     2 3
     2 4
     3 5

4.在循环中比较指针,如果match,就返回记录。如果不match,指向较小值的指针指向下一个记录。

>1  2<  - 不match, 左指针小,左指针++
 2  3
 2  4
 3  5

 1  2<  - match, 返回记录,两个指针都++
>2  3
 2  4
 3  5

 1  2  - match, 返回记录,两个指针都++
 2  3< 
 2  4
>3  5

 1  2 - 左指针越界,查询结束。
 2  3
 2  4<
 3  5
>

 

时间复杂度O(n*log(n)+m*log(m))。排序算法的复杂度分别是O(n*log(n))和O(m*log(m)),直接将两者相加。

在这种情况下,使查询更加复杂反而可以加快速度,因为更少的行需要经受join-level的测试?

当然了。

如果原来的query没有where语句,如

SELECT  T1.name, T2.date
FROM    T1, T2

是更简单的,但是会返回更多的结果并运行更长的时间。

  

 

hash函数的补充:

可以看到 hashFunction 涉及到indices 字段下表的计算。另外的distance计算使用了jaccard相似度。

from:https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala

/**
 * :: Experimental ::
 *
 * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function
 * is picked from the following family of hash functions, where a_i and b_i are randomly chosen
 * integers less than prime:
 *    `h_i(x) = ((x \cdot a_i + b_i) \mod prime)`
 *
 * This hash family is approximately min-wise independent according to the reference.
 *
 * Reference:
 * Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations."
 * Electronic Journal of Combinatorics 7 (2000): R26.
 *
 * @param randCoefficients Pairs of random coefficients. Each pair is used by one hash function.
 */
@Experimental
@Since("2.1.0")
class MinHashLSHModel private[ml](
    override val uid: String,
    private[ml] val randCoefficients: Array[(Int, Int)])
  extends LSHModel[MinHashLSHModel] {

  /** @group setParam */
  @Since("2.4.0")
  override def setInputCol(value: String): this.type = super.set(inputCol, value)

  /** @group setParam */
  @Since("2.4.0")
  override def setOutputCol(value: String): this.type = super.set(outputCol, value)

  @Since("2.1.0")
  override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
    require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
    val elemsList = elems.toSparse.indices.toList
    val hashValues = randCoefficients.map { case (a, b) =>
      elemsList.map { elem: Int =>
        ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME
      }.min.toDouble
    }
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    hashValues.map(Vectors.dense(_))
  }

  @Since("2.1.0")
  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
    val xSet = x.toSparse.indices.toSet
    val ySet = y.toSparse.indices.toSet
    val intersectionSize = xSet.intersect(ySet).size.toDouble
    val unionSize = xSet.size + ySet.size - intersectionSize
    assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
    1 - intersectionSize / unionSize
  }

  @Since("2.1.0")
  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
    // Since it's generated by hashing, it will be a pair of dense vectors.
    // TODO: This hashDistance function requires more discussion in SPARK-18454
    x.zip(y).map(vectorPair =>
      vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
    ).min
  }

  @Since("2.1.0")
  override def copy(extra: ParamMap): MinHashLSHModel = {
    val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
    copyValues(copied, extra)
  }

  @Since("2.1.0")
  override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
}

  

posted @ 2019-07-08 15:54  bonelee  阅读(695)  评论(0编辑  收藏  举报