MinHash in PySpark source analysis: the hash join table is the key
The analysis below shows that the hash values are computed first, and a hash join table is then used to bring together the rows whose hash values are equal; after that a UDF computes the distance between each pair, and the result is finally filtered down to the pairs that satisfy the threshold:
Reference: https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
```scala
/**
 * Join two datasets to approximately find all pairs of rows whose distance are smaller than
 * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the
 * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
 * data when necessary.
 *
 * @param datasetA One of the datasets to join.
 * @param datasetB Another dataset to join.
 * @param threshold The threshold for the distance of row pairs.
 * @param distCol Output column for storing the distance between each pair of rows.
 * @return A joined dataset containing pairs of rows. The original rows are in columns
 *         "datasetA" and "datasetB", and a column "distCol" is added to show the distance
 *         between each pair.
 */
def approxSimilarityJoin(
    datasetA: Dataset[_],
    datasetB: Dataset[_],
    threshold: Double,
    distCol: String): Dataset[_] = {
  val leftColName = "datasetA"
  val rightColName = "datasetB"
  val explodeCols = Seq("entry", "hashValue")
  val explodedA = processDataset(datasetA, leftColName, explodeCols)

  // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity.
  // TODO: Remove recreateCol logic once SPARK-17154 is resolved.
  val explodedB = if (datasetA != datasetB) {
    processDataset(datasetB, rightColName, explodeCols)
  } else {
    val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}")
    processDataset(recreatedB, rightColName, explodeCols)
  }

  // Do a hash join on where the exploded hash values are equal.
  val joinedDataset = explodedA.join(explodedB, explodeCols)
    .drop(explodeCols: _*).distinct()

  // Add a new column to store the distance of the two rows.
  val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
  val joinedDatasetWithDist = joinedDataset.select(col("*"),
    distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
  )

  // Filter the joined datasets where the distance are smaller than the threshold.
  joinedDatasetWithDist.filter(col(distCol) < threshold)
}
```
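For context, here is a minimal usage sketch of this whole pipeline, modeled on the official Spark MinHashLSH examples; the toy data, column names, and the pre-existing `spark` SparkSession are assumptions:

```scala
import org.apache.spark.ml.feature.MinHashLSH
import org.apache.spark.ml.linalg.Vectors

// Toy sparse binary vectors standing in for sets (assumed data).
val dfA = spark.createDataFrame(Seq(
  (0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
  (1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0))))
)).toDF("id", "features")

val dfB = spark.createDataFrame(Seq(
  (2, Vectors.sparse(6, Seq((1, 1.0), (3, 1.0), (5, 1.0)))),
  (3, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (5, 1.0))))
)).toDF("id", "features")

val mh = new MinHashLSH()
  .setNumHashTables(5)
  .setInputCol("features")
  .setOutputCol("hashes")

// fit() draws the random (a, b) coefficients; the join below then hashes,
// joins on equal hash values, computes Jaccard distance, and filters by 0.6.
val model = mh.fit(dfA)
model.approxSimilarityJoin(dfA, dfB, 0.6, "JaccardDistance")
  .select("datasetA.id", "datasetB.id", "JaccardDistance")
  .show()
```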
Supplement: time complexity of SQL join algorithms

(Notes adapted from an external reference.)

The SQL statement is as follows:
```sql
SELECT T1.name, T2.date
FROM T1, T2
WHERE T1.id = T2.id
  AND T1.color = 'red'
  AND T2.type = 'CAR'
```
Suppose T1 has m rows and T2 has n rows. In the naive case, we would visit each of T1's m rows, and for each one scan all n rows of T2 looking for the rows with T2.id = T1.id to join against. The time complexity is O(m*n), as sketched below.
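A minimal sketch of that naive nested-loop plan, with hypothetical row types introduced purely for illustration:

```scala
case class Row1(id: Int, name: String, color: String)
case class Row2(id: Int, date: String, kind: String)

// Every qualifying T1 row is checked against every T2 row: O(m*n) comparisons.
def nestedLoopJoin(t1: Seq[Row1], t2: Seq[Row2]): Seq[(String, String)] =
  for {
    r1 <- t1 if r1.color == "red"
    r2 <- t2 if r2.id == r1.id && r2.kind == "CAR"
  } yield (r1.name, r2.date)
```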
Without an index, the engine will instead choose a hash join or a merge join to do better.
A hash join works like this:
- Pick the table to be hashed, usually the smaller one. Let's happily assume T1 is the smaller.
- Scan all of T1's rows. Every row satisfying color = 'red' goes into a hash table, with id as the key and name as the value.
- Scan all of T2's rows. For every row satisfying type = 'CAR', probe the hash table with its id; each hit returns its stored name together with the current row's date, which joins the two.
The time complexity is O(m+n): building the hash table over T1 costs O(m), probing it once per T2 row costs O(n), and the two phases simply add. A sketch of the build-and-probe idea follows.
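A minimal in-memory sketch of that build-and-probe structure, again under hypothetical row types (an illustration of the algorithm, not the engine's actual code):

```scala
case class R1(id: Int, name: String, color: String)
case class R2(id: Int, date: String, kind: String)

def hashJoin(t1: Seq[R1], t2: Seq[R2]): Seq[(String, String)] = {
  // Build phase, O(m): hash the filtered smaller table on the join key.
  val hashTable: Map[Int, Seq[String]] =
    t1.filter(_.color == "red")
      .groupBy(_.id)
      .map { case (id, rows) => id -> rows.map(_.name) }
  // Probe phase, O(n): each qualifying T2 row does an O(1) expected lookup.
  for {
    r2 <- t2 if r2.kind == "CAR"
    name <- hashTable.getOrElse(r2.id, Nil)
  } yield (name, r2.date)
}
```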
A merge join works like this:

1. Copy T1 (id, name) and sort the copy by id.
2. Copy T2 (id, date) and sort the copy by id.
3. Point a pointer at the smallest id in each table. In the walkthroughs below, each row shows T1.id on the left and T2.id on the right; `>` marks the T1 pointer and `<` marks the T2 pointer:

```
>1  2<
 2  3
 2  4
 3  5
```

4. Loop, comparing the values under the two pointers. On a match, return the joined record and advance both pointers; otherwise advance the pointer sitting on the smaller value.

```
>1  2<    1 < 2: no match; the left value is smaller, so the left pointer advances.
 2  3
 2  4
 3  5

 1  2<    2 = 2: match; return the record and advance both pointers.
>2  3
 2  4
 3  5

 1  2     (2 < 3 was one more no-match step, advancing the left pointer past the duplicate 2.)
 2  3<    3 = 3: match; return the record and advance both pointers.
 2  4
>3  5

 1  2     The left pointer has run past the end of T1: the query is finished.
 2  3
 2  4<
 3  5
>
```
The time complexity is O(n*log(n) + m*log(m)): the two sorts cost O(n*log(n)) and O(m*log(m)) respectively and simply add, while the single merge pass is linear and dominated by the sorts. A sketch of the merge loop follows.
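A compact sketch of the sort-merge loop on already-projected (id, payload) pairs. This simplified variant advances past duplicates one step at a time, exactly like the pointer walkthrough above; a production join would also handle many-to-many duplicate keys:

```scala
def mergeJoin(t1: Seq[(Int, String)], t2: Seq[(Int, String)]): Seq[(String, String)] = {
  val left  = t1.sortBy(_._1)  // O(m*log(m))
  val right = t2.sortBy(_._1)  // O(n*log(n))
  val out = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
  var i = 0
  var j = 0
  // The merge pass itself is linear, O(m + n).
  while (i < left.length && j < right.length) {
    if (left(i)._1 == right(j)._1) {
      out += ((left(i)._2, right(j)._2))  // match: return the record, advance both
      i += 1; j += 1
    } else if (left(i)._1 < right(j)._1) {
      i += 1                              // advance the pointer at the smaller id
    } else {
      j += 1
    }
  }
  out.toSeq
}
```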
So in this situation, making the query more complex can actually make it faster, because fewer rows have to undergo the join-level test?

Of course.
If the original query had no WHERE clause at all, such as

```sql
SELECT T1.name, T2.date
FROM T1, T2
```

the statement itself would be simpler, but as a cross product it would return far more results and run far longer.
A supplement on the hash function:
You can see that hashFunction computes over the indices of the vector's non-zero entries (the subscripts of the set members). The separate distance computation, keyDistance, uses the Jaccard distance (one minus the Jaccard similarity).
```scala
/**
 * :: Experimental ::
 *
 * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function
 * is picked from the following family of hash functions, where a_i and b_i are randomly chosen
 * integers less than prime:
 *    `h_i(x) = ((x \cdot a_i + b_i) \mod prime)`
 *
 * This hash family is approximately min-wise independent according to the reference.
 *
 * Reference:
 * Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations."
 * Electronic Journal of Combinatorics 7 (2000): R26.
 *
 * @param randCoefficients Pairs of random coefficients. Each pair is used by one hash function.
 */
@Experimental
@Since("2.1.0")
class MinHashLSHModel private[ml](
    override val uid: String,
    private[ml] val randCoefficients: Array[(Int, Int)])
  extends LSHModel[MinHashLSHModel] {

  /** @group setParam */
  @Since("2.4.0")
  override def setInputCol(value: String): this.type = super.set(inputCol, value)

  /** @group setParam */
  @Since("2.4.0")
  override def setOutputCol(value: String): this.type = super.set(outputCol, value)

  @Since("2.1.0")
  override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
    require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
    val elemsList = elems.toSparse.indices.toList
    val hashValues = randCoefficients.map { case (a, b) =>
      elemsList.map { elem: Int =>
        ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME
      }.min.toDouble
    }
    // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
    hashValues.map(Vectors.dense(_))
  }

  @Since("2.1.0")
  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
    val xSet = x.toSparse.indices.toSet
    val ySet = y.toSparse.indices.toSet
    val intersectionSize = xSet.intersect(ySet).size.toDouble
    val unionSize = xSet.size + ySet.size - intersectionSize
    assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
    1 - intersectionSize / unionSize
  }

  @Since("2.1.0")
  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
    // Since it's generated by hashing, it will be a pair of dense vectors.
    // TODO: This hashDistance function requires more discussion in SPARK-18454
    x.zip(y).map(vectorPair =>
      vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
    ).min
  }

  @Since("2.1.0")
  override def copy(extra: ParamMap): MinHashLSHModel = {
    val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
    copyValues(copied, extra)
  }

  @Since("2.1.0")
  override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
}
```
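As a quick sanity check of the two formulas above, here is a toy, self-contained sketch; the prime value is copied from MinHashLSH.HASH_PRIME in the Spark source, and the helper names are assumptions for illustration:

```scala
// Stands in for MinHashLSH.HASH_PRIME in the real code.
val HashPrime = 2038074743L

// One min-hash value for a set, given one (a, b) coefficient pair:
// the minimum over the set's indices of ((1 + elem) * a + b) mod prime.
def minHash(indices: Seq[Int], a: Int, b: Int): Long =
  indices.map(elem => ((1L + elem) * a + b) % HashPrime).min

// keyDistance: Jaccard distance 1 - |x ∩ y| / |x ∪ y| over the index sets.
def jaccardDistance(x: Set[Int], y: Set[Int]): Double = {
  val inter = x.intersect(y).size.toDouble
  1 - inter / (x.size + y.size - inter)
}

println(minHash(List(0, 1, 2), 15, 37))              // min(52, 67, 82) = 52
println(jaccardDistance(Set(0, 1, 2), Set(1, 2, 3))) // |∩| = 2, |∪| = 4, so 0.5
```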