kd-Tree 【专题@AbandonZHANG】

刚开始学习,介绍先搁着~等理解透彻了再来写~~~   我是学习的mzry1992(UESTC_Izayoi ~)------http://www.mzry1992.com/blog/miao/kd%E6%A0%91.html 先去看mzry1992大牛博客里的讲解吧。。。   再附两篇论文:(看英文看得好爽。。。~@.@) 《An intoductory tutorial on kd-trees》  ★(里面就介绍了kd-tree和nearest neighbour algorithm(最近邻算法)、Q nearest neighbour(Q近邻) 《Range Searching Using Kd-Tree.》    kd-tree入门题:     HDOJ 2966 In case of failure (最近邻,模板~)   查找平面点最近点的距离(此题中是距离的平方)
/*
    HDOJ 2966
    KD-Tree模板
*/
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define MID(x, y) ( (x + y)>>1 )

using namespace std;
typedef long long LL;

//KD-Tree模板
const int N=100005;
LL res;
struct Point
{
    int x, y;        //点是二维的,此时是2D-Tree
};

LL dist2(const Point &a, const Point &b)            //距离的平方
{
    return LL(a.x - b.x) * LL(a.x - b.x) + LL(a.y - b.y) * LL(a.y - b.y);
}

bool cmpX(const Point &a, const Point &b)
{
    return a.x < b.x;
}
bool cmpY(const Point &a, const Point &b)
{
    return a.y < b.y;
}

struct KDTree        //很崇拜这种销魂的建树方法啊~0.0~很抽象很强大------p数组已经代表了KD-Tree了,神马左右子树全省了,OOOOOrrz!!!……
{
    Point p[N];        //空间内的点
    int Div[N];        //记录区间是按什么方式划分(分割线平行于x轴还是y轴, ==1平行y轴切;==0平行x轴切)

    void build(int l, int r)            //记得先把p备份一下。
    {
        if (l > r)    return;
        int mid=MID(l, r);
        int minX, minY, maxX, maxY;
        minX = min_element(p + l, p + r + 1, cmpX)->x;
        minY = min_element(p + l, p + r + 1, cmpY)->y;
        maxX = max_element(p + l, p + r + 1, cmpX)->x;
        maxY = max_element(p + l, p + r + 1, cmpY)->y;
        Div[mid] = (maxX - minX >= maxY - minY);
        nth_element(p + l, p + mid, p + r + 1, Div[mid] ? cmpX : cmpY);
        build(l, mid - 1);
        build(mid+1, r);
    }

    void find(int l, int r, Point a)                //查找最近点的平方距离
    {
        if (l > r)    return;
        int mid = MID(l, r);
        LL dist = dist2(a, p[mid]);
        if (dist > 0)   //如果有重点不能这么判断
            res = min(res, dist);
        LL d = Div[mid] ? (a.x - p[mid].x) : (a.y - p[mid].y);
        int l1, l2, r1, r2;
        l1 = l , l2 = mid + 1;
        r1 = mid - 1, r2 = r;
        if (d > 0)
            swap(l1, l2), swap(r1, r2);
        find(l1, r1, a);
        if (d * d < res)
            find(l2, r2, a);
    }
};
  HDOJ 4347 The Closest M Points (Q近邻)   与上题不同的是,一是k维(这个好处理~),二是求最近的m个点而不单是最近点了。这个也好处理~递归查找时处理的时候采取如下策略:如果当前找到的点小于k个,那么两个区间都要处理。。否则根据当前找到的第k个点决定是否去另外一个区间,如果目标点到分界线的距离大于等于已经找到的第k远的点,那么就不用查找另一个分界了。。。更新答案可以用一个大小为k的堆去维护(一个最大堆,一旦超过k个点就把最大的扔掉)。。。
#include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #define MID(x,y) ( (x + y)>>1 )
 
 using namespace std;
 typedef long long LL;
 
 //KD-Tree模板
 const int N=50005;
 
 struct Point
 {
     int x[5];
     LL dis;
     Point()
     {
         for (int i = 0; i < 5; i++)
             x[i] = 0;
         dis = 9223372036854775807LL;
     }
     friend bool operator < (const Point &a, const Point &b)
     {
         return a.dis < b.dis;
     }
 };
 priority_queue  > res;
 
 LL dist2(const Point &a, const Point &b, int k)            //距离的平方,开根号很耗时,能不开就不开
 {
     LL ans = 0;
     for (int i = 0; i < k; i++)         //一开始这儿写的i < 5,WA了N次。。。原本以为只要初始值设0就无所谓,
         ans += (a.x[i] - b.x[i]) * (a.x[i] - b.x[i]);   //但发现Point有个全局变量,在多case下会出错。。。
     return ans;
 }
 
 int ddiv;
 bool cmpX(const Point &a, const Point &b)
 {
     return a.x[ddiv] < b.x[ddiv];
 }
 
 struct KDTree        //很崇拜这种销魂的建树方法啊~0.0~很抽象很强大------p数组已经代表了KD-Tree了,神马左右子树全省了,OOOOOrrz!!!……
 {
     Point p[N];        //空间内的点
     int Div[N];        //记录区间是按什么方式划分(分割线平行于x轴还是y轴, ==1平行y轴切;==0平行x轴切)
     int k;          //维数
     int m;          //近邻
 
     int getDiv(int l, int r)        //寻找区间跨度最大的划分方式
     {
         map  ms;
         int minx[5],maxx[5];
         for (int i = 0; i < k; i++)
         {
             ddiv = i;
             minx[i] = min_element(p + l, p + r + 1, cmpX)->x[i];
             maxx[i] = max_element(p + l, p + r + 1, cmpX)->x[i];
             ms[maxx[i] - minx[i]] = i;
         }
         map ::iterator pm = ms.end();
         pm--;
         return pm->second;
     }
 
     void build(int l, int r)            //记得先把p备份一下。
     {
         if (l > r)    return;
         int mid = MID(l,r);
         Div[mid] = getDiv(l,r);
         ddiv = Div[mid];
         nth_element(p + l, p + mid, p + r + 1, cmpX);
         build(l, mid - 1);
         build(mid + 1, r);
     }
 
     void findk(int l, int r, Point a)                //k(m)近邻,查找k近点的平方距离
     {
         if (l > r)    return;
         int mid = MID(l,r);
         LL dist = dist2(a, p[mid], k);
         if (dist >= 0)
         {
             p[mid].dis = dist;
             res.push(p[mid]);
             while ((int)res.size() > m)
                 res.pop();
         }
         LL d = a.x[Div[mid]] - p[mid].x[Div[mid]];
         int l1, l2, r1, r2;
         l1 = l , l2 = mid + 1;
         r1 = mid - 1, r2 = r;
         if (d > 0)
             swap(l1, l2), swap(r1, r2);
         findk(l1, r1, a);
         if ((int)res.size() < m || d*d < res.top().dis )
             findk(l2, r2, a);
     }
 };
 
 Point pp[N];
 KDTree kd;
 
 int main()
 {
 //    freopen("test.txt","r+",stdin);
 //    freopen("ans.txt","w+",stdout);
 
     int n;
     while(scanf("%d%d", &n, &kd.k)!=EOF)
     {
         for (int i = 0; i < n; i++)
             for (int j = 0; j < kd.k; j++)
             {
                 scanf("%d", &pp[i].x[j]);
                 kd.p[i] = pp[i];
             }
         kd.build(0, n - 1);
         int t;
         scanf("%d", &t);
         while(t--)
         {
             Point a;
             for (int i = 0; i < kd.k; i++)
                 scanf("%d", &a.x[i]);
             scanf("%d", &kd.m);
             kd.findk(0, n - 1, a);
             printf("the closest %d points are:\n", kd.m);
             Point ans[11];
             for (int i = 0; !res.empty(); i++)
             {
                 ans[i] = res.top();
                 res.pop();
             }
             for (int i = kd.m - 1; i >= 0; i--)
             {
                 for (int j = 0; j < kd.k - 1; j++)
                     printf("%d ", ans[i].x[j]);
                 printf("%d\n", ans[i].x[kd.k - 1]);
             }
         }
 
     }
     return 0;
 }
 
posted @ 2012-09-21 09:27  AbandonZHANG  阅读(181)  评论(0编辑  收藏  举报