Google interview question: k-nearest neighbor (k-d tree)
Question:
You are given information about hotels in a country/city. X and Y coordinates of each hotel are known. You need to suggest the list of nearest hotels to a user who is querying from a particular point (X and Y coordinates of the user are given). Distance is calculated as the straight line distance between the user and the hotel coordinates.
假设数据大小为N,需要寻找k个最近的酒店,最直接的做法就是计算每一家酒店离查询坐标的距离,用一个小堆来记录最近的k个酒店,时间复杂度为O(Nlog(k)),空间复杂度为O(k)。
我们可以通过对数据进行预处理来达到优化查询效率的方法。先对所有酒店的坐标按x坐标排序。对于查询坐标(x,y),给定a(一个猜测的值),通过二分查找区间[x-a,x+a],可以获得所有x坐标在区间内的酒店,再通过上一个方法的小堆记录最近的k个酒店。我们并不十分关心数据预处理的效率,时间复杂度为O(Nlog(N)),对于查询,二分查找的时间复杂度为O(log(N)),假设通过二分查找筛选出的结果有m个,第二步的时间复杂度为O(mlog(k))。因此,对于查询的时间复杂度为O(log(N))+O(mlog(k))。这个做法的问题在于如何确定a的值,不适当的选取会导致(1)m过大,使得查询效率降低;(2)结果不准确,因为可能的k最近酒店在x区间之外。如果不要求结果完全准确地近似做法,其效率远高于上一个方法。
第二个方法将数据对x坐标进行了排序,但y坐标仍然是无序的。我们是否可以进一步优化数据的结构,来提高查询的效率?k-d tree是一种可以考虑的数据结构来解决这类问题。有关k-d tree的概念和原理,以下这篇文章介绍的非常详细。
http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf
首先,我们的问题是二维平面上的点,因此在这题中只需要实现k-d tree的二维情况。先定义点和k-d tree的节点。
class Point2D{ int x; int y; public Point2D(int x, int y){ this.x = x; this.y = y; } } class KdTreeNode{ Point2D val; KdTreeNode left; KdTreeNode right; public KdTreeNode(Point2D p) { this.val = p; } }
对于k-d tree,其定义方式有很多。有些实现中所有数据存放在叶子节点,内部节点只存放划分空间的信息。在这里,只需当做和普通的binary search tree一样处理。
接下来是k-d tree的构建。在对k-d tree有一定了解之后,我们知道对于树的每一层,交替进行这对x和y轴的划分。在这个题目中,k-d tree的构建属于数据的预处理,静态的数据,并不需要考虑k-d tree的插入删除等操作。我们选择以x坐标划分,选择的坐标作为一个节点,将数据划分为两个部分,左边部分所有数据的x坐标都不大于该节点的x坐标,右边部分所有数据的x坐标都不小于该节点的x坐标。然后递归进行,定义根节点为第0层,当层数为偶数时以x坐标划分,当层数为奇数时以y坐标划分。那如何选择划分的节点坐标?为了使k-d tree的查询是高效的,构建的k-d tree需要平衡,因此选择的节点是x/y坐标的中位数。用selection algorithm,可以在O(N)时间内找到该节点,将数据均匀地划分。
public static KdTreeNode constructKdTree(Point2D[] array, int depth, int low, int high){ if(low > high) return null; if(low == high) return new KdTreeNode(array[low]); int mid = low+(high-low)/2; Point2D p = quickSelect(array, mid, low, high, depth%2); KdTreeNode node = new KdTreeNode(p); node.left = constructKdTree(array, depth+1, low, mid-1); node.right = constructKdTree(array, depth+1, mid+1, high); return node; } public static Point2D quickSelect(Point2D[] array, int k, int low, int high, int dimension){ while(low<=high){ int pivotIndex = partition(array, low, high, new Random().nextInt(high-low+1)+low,dimension); if(pivotIndex == k) return array[k]; else if(pivotIndex < k) low = pivotIndex+1; else high = pivotIndex-1; } return null; } public static int partition(Point2D[] array, int low, int high, int pivot, int dimension){ int pivotVal = dimension==0?array[pivot].x:array[pivot].y; swap(array, pivot, high); int index = low; for(int i=low;i<high;i++){ int curVal = dimension==0?array[i].x:array[i].y; if(curVal<pivotVal){ swap(array, index, i); index++; } } swap(array, high, index); return index; } public static void swap(Point2D[] array, int i, int j){ if(i!=j){ Point2D tmp = array[i]; array[i] = array[j]; array[j] = tmp; } }
k-d tree的构建完成,时间复杂度为O(Nlog(N)),空间复杂度为O(N)。接下来是k-d tree的查询操作。我们的问题是要获得最近的酒店列表。在链接的文章中介绍了如何查询最近的坐标和最近的k个坐标。对于这题,我只实现返回最近的坐标。
首先通过查询坐标寻找该坐标所在的划分空间,记录下遍历路径中的离该坐标最近的点。然后根据这个距离r得到中心点为查询坐标,半径为r的搜索空间,再次对kd-tree查询是否存在更近的点。
static Point2D nearestPoint = new Point2D(Integer.MAX_VALUE, Integer.MAX_VALUE); static int min = Integer.MAX_VALUE; public static void queryHelper(KdTreeNode root, Point2D query, int depth){ if(root == null) return; int distance = (query.x-root.val.x)*(query.x-root.val.x)+(query.y-root.val.y)*(query.y-root.val.y); if(distance < min){ min = distance; nearestPoint = root.val; } int curVal = depth%2==0?query.x:query.y; int nodeVal = depth%2==0?query.x:query.y; if(curVal > nodeVal) queryHelper(root.right, query, depth+1); else if(curVal < nodeVal) queryHelper(root.left, query, depth+1); else{ queryHelper(root.right, query, depth+1); queryHelper(root.left, query, depth+1); } } public static void queryNearestHelper(KdTreeNode root, Point2D query, int depth, double xMin, double xMax, double yMin, double yMax){ if(root == null) return; int distance = (query.x-root.val.x)*(query.x-root.val.x)+(query.y-root.val.y)*(query.y-root.val.y); if(distance < min){ min = distance; nearestPoint = root.val; } int curVal = depth%2==0?query.x:query.y; int nodeVal = depth%2==0?query.x:query.y; double rangeMin = depth%2==0?xMin:yMin; double rangeMax = depth%2==0?xMax:yMax; if(curVal > nodeVal){ queryNearestHelper(root.right, query, depth+1, xMin, xMax, yMin, yMax); if(nodeVal > rangeMin) queryNearestHelper(root.left, query, depth+1, xMin, xMax, yMin, yMax); } else if(curVal < nodeVal){ queryNearestHelper(root.left, query, depth+1, xMin, xMax, yMin, yMax); if(nodeVal < rangeMax) queryNearestHelper(root.right, query, depth+1, xMin, xMax, yMin, yMax); } else{ queryHelper(root.right, query, depth+1); queryHelper(root.left, query, depth+1); } } public static void queryNearest(KdTreeNode root, Point2D query){ queryHelper(root, query, 0); double xMin = query.x-Math.sqrt(min), xMax = query.x+Math.sqrt(min), yMin = query.y-Math.sqrt(min), yMax = query.y+Math.sqrt(min); queryNearestHelper(root, query, 0, xMin, xMax, yMin, yMax); }
至此,我们完成了最近酒店的查询,时间复杂度为O(log(N))。