POJ 4091: The Closest M Points (KD-tree)
kd-tree references (参考资料):
http://my.oschina.net/keyven/blog/221792
http://blog.csdn.net/lsjseu/article/details/12344443
// POJ 4091 / "The Closest M Points": for each query point, print the m
// nearest points of a static k-dimensional point set, found with a kd-tree.
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cstring>
#include <vector>

const int MAXD = 5;      // maximum dimensionality k
const int MAXT = 50010;  // maximum number of points n (problem bound 50000, plus slack)

int n, k, m, cur = 1, curd;  // cur: next free kdt slot; curd: split dimension for comp()

// A point with up to MAXD coordinates; only the first k are used.
struct point {
    int a[MAXD];
} pa[MAXT], tp;  // pa: the point set (reordered in place by build_kdt); tp: current query point

inline long long sqr(int x) { return static_cast<long long>(x) * x; }

// Squared Euclidean distance over the first k coordinates.
// Computed in long long so it cannot overflow for moderately large coordinates.
long long dis(const point &a, const point &b) {
    long long res = 0;
    for (int i = 0; i < k; ++i)
        res += sqr(a.a[i] - b.a[i]);
    return res;
}

// Orders point indices by distance to tp. Used with std::priority_queue, this
// gives a max-heap whose top is the farthest of the current candidates.
struct cmp {
    bool operator()(int i, int j) const {
        return dis(tp, pa[i]) < dis(tp, pa[j]);
    }
} less_than;

// Compares two points on dimension curd (the comparator for nth_element in build_kdt).
bool comp(const point &i, const point &j) {
    return i.a[curd] < j.a[curd];
}

std::priority_queue<int, std::vector<int>, cmp> max_heap;

// kd-tree node. It splits on dimension `dim` at point pa[idx]; lc/rc are
// child indices into kdt, with 0 meaning "no child".
struct node {
    int dim, lc, rc, idx;
} kdt[MAXT * 2];

// Builds the kd-tree over pa[lf..rt], splitting on dimension d at the median.
// Returns the index of the subtree root in kdt, or 0 for an empty range.
int build_kdt(int lf, int rt, int d) {
    if (lf > rt) return 0;
    int mid = (lf + rt) / 2;
    curd = d;
    std::nth_element(pa + lf, pa + mid, pa + rt + 1, comp);
    int id = cur++;
    // Children are assigned explicitly, so kdt needs no clearing between test cases.
    kdt[id].dim = d;
    kdt[id].idx = mid;
    int next_dim = (d + 1) % k;
    int left = build_kdt(lf, mid - 1, next_dim);
    int right = build_kdt(mid + 1, rt, next_dim);
    kdt[id].lc = left;
    kdt[id].rc = right;
    return id;
}

// Searches subtree r, keeping in max_heap the (up to) m points nearest to tp.
void query(int r) {
    if (r <= 0 || m <= 0) return;  // m <= 0 would otherwise call top() on an empty heap
    const node &rt = kdt[r];
    // Consider the splitting point stored at this node.
    if (static_cast<int>(max_heap.size()) < m)
        max_heap.push(rt.idx);
    else if (less_than(rt.idx, max_heap.top())) {
        max_heap.pop();
        max_heap.push(rt.idx);
    }
    int diff = tp.a[rt.dim] - pa[rt.idx].a[rt.dim];
    int near_child = diff < 0 ? rt.lc : rt.rc;
    int far_child = diff < 0 ? rt.rc : rt.lc;
    // Search the side containing the query first. Then search the other side
    // only if the splitting plane is no farther than the current m-th best
    // distance (or fewer than m points have been found yet).
    query(near_child);
    if (static_cast<int>(max_heap.size()) < m || dis(pa[max_heap.top()], tp) >= sqr(diff))
        query(far_child);
}

// Drains max_heap and prints the collected points nearest-first: one point per
// line, each coordinate followed by a space.
void print() {
    std::vector<int> order;
    order.reserve(max_heap.size());
    while (!max_heap.empty()) {
        order.push_back(max_heap.top());
        max_heap.pop();
    }
    // The heap yields points farthest-first, so print in reverse.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        for (int j = 0; j < k; ++j)
            std::printf("%d ", pa[*it].a[j]);
        std::printf("\n");
    }
}

int main() {
    int t;
    while (std::scanf("%d%d", &n, &k) == 2) {
        // Reject input that would overflow the fixed arrays or divide by zero in build_kdt.
        if (n < 0 || n > MAXT || k < 1 || k > MAXD) break;
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < k; ++j)
                std::scanf("%d", &pa[i].a[j]);
        cur = 1;
        int root = build_kdt(0, n - 1, 0);
        if (std::scanf("%d", &t) != 1) break;
        while (t--) {
            for (int i = 0; i < k; ++i) std::scanf("%d", &tp.a[i]);
            std::scanf("%d", &m);
            std::printf("the closest %d points are:\n", m);
            while (!max_heap.empty()) max_heap.pop();
            query(root);
            print();
        }
    }
    return 0;
}