Google interview question: k-nearest neighbor (k-d tree)

Question:

You are given information about hotels in a country/city. X and Y coordinates of each hotel are known. You need to suggest the list of nearest hotels to a user who is querying from a particular point (X and Y coordinates of the user are given). Distance is calculated as the straight line distance between the user and the hotel coordinates.

假设数据大小为N,需要寻找k个最近的酒店,最直接的做法就是计算每一家酒店离查询坐标的距离,用一个小堆来记录最近的k个酒店,时间复杂度为O(Nlog(k)),空间复杂度为O(k)。

我们可以通过对数据进行预处理来达到优化查询效率的方法。先对所有酒店的坐标按x坐标排序。对于查询坐标(x,y),给定a(一个猜测的值),通过二分查找区间[x-a,x+a],可以获得所有x坐标在区间内的酒店,再通过上一个方法的小堆记录最近的k个酒店。我们并不十分关心数据预处理的效率,时间复杂度为O(Nlog(N)),对于查询,二分查找的时间复杂度为O(log(N)),假设通过二分查找筛选出的结果有m个,第二步的时间复杂度为O(mlog(k))。因此,对于查询的时间复杂度为O(log(N))+O(mlog(k))。这个做法的问题在于如何确定a的值,不适当的选取会导致(1)m过大,使得查询效率降低;(2)结果不准确,因为可能的k最近酒店在x区间之外。如果不要求结果完全准确地近似做法,其效率远高于上一个方法。

第二个方法将数据对x坐标进行了排序,但y坐标仍然是无序的。我们是否可以进一步优化数据的结构,来提高查询的效率?k-d tree是一种可以考虑的数据结构来解决这类问题。有关k-d tree的概念和原理,以下这篇文章介绍的非常详细。

http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf

首先,我们的问题是二维平面上的点,因此在这题中只需要实现k-d tree的二维情况。先定义点和k-d tree的节点。

class Point2D{
    int x;
    int y;
    public Point2D(int x, int y){
        this.x = x;
        this.y = y;
    }
}
class KdTreeNode{
    Point2D val;
    KdTreeNode left;
    KdTreeNode right;
    public KdTreeNode(Point2D p) {
        this.val = p;
    }
}

对于k-d tree,其定义方式有很多。有些实现中所有数据存放在叶子节点,内部节点只存放划分空间的信息。在这里,只需当做和普通的binary search tree一样处理。

接下来是k-d tree的构建。在对k-d tree有一定了解之后,我们知道对于树的每一层,交替进行这对x和y轴的划分。在这个题目中,k-d tree的构建属于数据的预处理,静态的数据,并不需要考虑k-d tree的插入删除等操作。我们选择以x坐标划分,选择的坐标作为一个节点,将数据划分为两个部分,左边部分所有数据的x坐标都不大于该节点的x坐标,右边部分所有数据的x坐标都不小于该节点的x坐标。然后递归进行,定义根节点为第0层,当层数为偶数时以x坐标划分,当层数为奇数时以y坐标划分。那如何选择划分的节点坐标?为了使k-d tree的查询是高效的,构建的k-d tree需要平衡,因此选择的节点是x/y坐标的中位数。用selection algorithm,可以在O(N)时间内找到该节点,将数据均匀地划分。

    public static KdTreeNode constructKdTree(Point2D[] array, int depth, int low, int high){
        if(low > high) return null;
        if(low == high) return new KdTreeNode(array[low]);
        int mid = low+(high-low)/2;
        Point2D p = quickSelect(array, mid, low, high, depth%2);
        KdTreeNode node = new KdTreeNode(p);
        node.left = constructKdTree(array, depth+1, low, mid-1);
        node.right = constructKdTree(array, depth+1, mid+1, high);
        return node;
    }
    public static Point2D quickSelect(Point2D[] array, int k,  int low, int high, int dimension){
        while(low<=high){
            int pivotIndex = partition(array, low, high, new Random().nextInt(high-low+1)+low,dimension);
            if(pivotIndex == k) return array[k];
            else if(pivotIndex < k) low = pivotIndex+1;
            else high = pivotIndex-1;
        }
        return null;
    }
    public static int partition(Point2D[] array, int low, int high, int pivot, int dimension){
        int pivotVal = dimension==0?array[pivot].x:array[pivot].y;
        swap(array, pivot, high);
        int index = low;
        for(int i=low;i<high;i++){
            int curVal = dimension==0?array[i].x:array[i].y;
            if(curVal<pivotVal){
                swap(array, index, i);
                index++;
            }
        }
        swap(array, high, index);
        return index;
    }
    public static void swap(Point2D[] array, int i, int j){
        if(i!=j){
            Point2D tmp = array[i];
            array[i] = array[j];
            array[j] = tmp;
        }
    }

k-d tree的构建完成,时间复杂度为O(Nlog(N)),空间复杂度为O(N)。接下来是k-d tree的查询操作。我们的问题是要获得最近的酒店列表。在链接的文章中介绍了如何查询最近的坐标和最近的k个坐标。对于这题,我只实现返回最近的坐标。

首先通过查询坐标寻找该坐标所在的划分空间,记录下遍历路径中的离该坐标最近的点。然后根据这个距离r得到中心点为查询坐标,半径为r的搜索空间,再次对kd-tree查询是否存在更近的点。

    static Point2D nearestPoint = new Point2D(Integer.MAX_VALUE, Integer.MAX_VALUE);
    static int min = Integer.MAX_VALUE;
    public static void queryHelper(KdTreeNode root, Point2D query, int depth){
        if(root == null) return;
        int distance = (query.x-root.val.x)*(query.x-root.val.x)+(query.y-root.val.y)*(query.y-root.val.y);
        if(distance < min){
            min = distance;
            nearestPoint = root.val;
        }
        int curVal = depth%2==0?query.x:query.y;
        int nodeVal = depth%2==0?query.x:query.y;
        if(curVal > nodeVal) queryHelper(root.right, query, depth+1);
        else if(curVal < nodeVal) queryHelper(root.left, query, depth+1);
        else{
            queryHelper(root.right, query, depth+1);
            queryHelper(root.left, query, depth+1);
        }
    }
    public static void queryNearestHelper(KdTreeNode root, Point2D query, int depth, double xMin, double xMax, double yMin, double yMax){
        if(root == null) return;
        int distance = (query.x-root.val.x)*(query.x-root.val.x)+(query.y-root.val.y)*(query.y-root.val.y);
        if(distance < min){
            min = distance;
            nearestPoint = root.val;
        }
        int curVal = depth%2==0?query.x:query.y;
        int nodeVal = depth%2==0?query.x:query.y;
        double rangeMin = depth%2==0?xMin:yMin;
        double rangeMax = depth%2==0?xMax:yMax;
        if(curVal > nodeVal){
            queryNearestHelper(root.right, query, depth+1, xMin, xMax, yMin, yMax);
            if(nodeVal > rangeMin) queryNearestHelper(root.left, query, depth+1, xMin, xMax, yMin, yMax);
        }
        else if(curVal < nodeVal){
            queryNearestHelper(root.left, query, depth+1, xMin, xMax, yMin, yMax);
            if(nodeVal < rangeMax)  queryNearestHelper(root.right, query, depth+1, xMin, xMax, yMin, yMax);
        }
        else{
            queryHelper(root.right, query, depth+1);
            queryHelper(root.left, query, depth+1);
        }
    }
    public static void queryNearest(KdTreeNode root, Point2D query){
        queryHelper(root, query, 0);
        double xMin = query.x-Math.sqrt(min), xMax = query.x+Math.sqrt(min), yMin = query.y-Math.sqrt(min), yMax = query.y+Math.sqrt(min);
        queryNearestHelper(root, query, 0, xMin, xMax, yMin, yMax);
    }

至此,我们完成了最近酒店的查询,时间复杂度为O(log(N))。

 

posted @ 2015-07-01 03:21  dshao  阅读(347)  评论(0编辑  收藏  举报