poj4091 The Closest M Points (KD-tree)

kd-tree参考资料:

http://my.oschina.net/keyven/blog/221792

http://blog.csdn.net/lsjseu/article/details/12344443

 

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <queue>
  4 #include <cstring>
  5 using namespace std;
  6 
  7 const int MAXD = 5, MAXT = 5010;
  8 int n, k, m, cur=1, curd;
  9 
 10 class point{
 11 public:
 12     int a[MAXD];
 13 }pa[MAXT], tp;
 14 inline int sqr(int a){return a*a; }
 15 int dis(point &a, point &b){
 16     int res=0;
 17     for(int i=0; i<k; ++i)
 18         res += sqr(a.a[i]-b.a[i]);
 19     return res;
 20 }
 21 
 22 class cmp{
 23 public:
 24     int operator()(const int &i, const int &j){
 25         return dis(tp,pa[i])<dis(tp,pa[j]);
 26     }
 27 }less_than;
 28 
 29 int comp(const point &i, const point &j){// 比较curd维的大小
 30     return i.a[curd]<j.a[curd];
 31 }    
 32 priority_queue<int,vector<int>,cmp > max_heap;
 33 
 34 class node{
 35 public:
 36     int l,r,dim,lc,rc,idx;
 37 }kdt[MAXT*2];// 线段树,范围:[l,r]
 38 void build_kdt(int lf, int rt, int d){
 39     int m = (lf+rt)>>1;
 40     curd = d;
 41     nth_element(pa+lf, pa+m, pa+rt+1, comp);
 42     node &root = kdt[cur++];
 43     root.l = lf, root.r = rt, root.dim = d,    root.idx = m;
 44     if(lf<m){
 45         root.lc = cur;
 46         build_kdt(lf, m-1, (d+1)%k);
 47     }
 48     if(m<rt){
 49         root.rc = cur;
 50         build_kdt(m+1, rt, (d+1)%k);
 51     }
 52 }
 53 void query(int r){
 54     if(r<=0) return;
 55     node &rt = kdt[r];
 56     // 访问当前结点
 57     if(max_heap.size()<m)
 58         max_heap.push(rt.idx);
 59     else if(less_than(rt.idx, max_heap.top())){
 60         max_heap.pop();
 61         max_heap.push(rt.idx);
 62     }
 63     int fir, sec;
 64     if(tp.a[rt.dim]<pa[rt.idx].a[rt.dim])
 65         fir = rt.lc, sec = rt.rc;
 66     else fir = rt.rc, sec = rt.lc;
 67     query(fir);// 访问目标所在的结点的子树,并进行是否在另一子树的判断
 68     if(max_heap.size()<m || dis(pa[max_heap.top()], tp)>=sqr(tp.a[rt.dim]-pa[rt.idx].a[rt.dim])){
 69         query(sec);
 70     }
 71 }
 72 
 73 void print(){// 从小到大输出最大堆
 74     if(!max_heap.empty()){
 75         int i = max_heap.top(); max_heap.pop();
 76         print();
 77         for(int j=0; j<k; ++j)
 78             printf("%d ", pa[i].a[j]);
 79         printf("\n");
 80     }
 81 }
 82 
 83 int main(){
 84     // freopen("in.txt", "r", stdin);
 85     int t;
 86     while(scanf("%d%d", &n, &k) == 2){
 87         for(int i=0; i<n; ++i){
 88             for(int j=0; j<k; ++j)
 89                 scanf("%d", &pa[i].a[j]);
 90         }
 91         memset(kdt, 0, sizeof(kdt));
 92         cur = 1;
 93         build_kdt(0,n-1,0);
 94         scanf("%d", &t);
 95         while(t--){
 96             for(int i=0; i<k; ++i) scanf("%d", &tp.a[i]);
 97             scanf("%d", &m);
 98             printf("the closest %d points are:\n", m);
 99             while(!max_heap.empty()) max_heap.pop();
100             query(1);
101             print();
102         }
103     }
104     return 0;
105 }

 

posted @ 2016-08-15 11:51  Keep_Going  阅读(371)  评论(0编辑  收藏  举报