求与询问点欧几里德距离前m小的点
其实就是在kdtree询问的时候用优先队列维护一下就好了
好久没写kdtree练一练,注意这道题是多测
1 #include<bits/stdc++.h> 2 3 using namespace std; 4 const int inf=1e4+5; 5 int key,root,n,m,q,k,mxd; 6 int sqr(int x) 7 { 8 return x*x; 9 } 10 11 struct point 12 { 13 int d[7]; 14 friend int dis(point a,point b) 15 { 16 int s=0; 17 for (int i=0; i<k; i++) 18 s+=sqr(a.d[i]-b.d[i]); 19 return s; 20 } 21 } po; 22 23 struct node 24 { 25 point nw; 26 int son[2],mi[7],mx[7]; 27 friend bool operator <(node a,node b) 28 { 29 return a.nw.d[key]<b.nw.d[key]; 30 } 31 }; 32 33 struct li 34 { 35 point a; int l; 36 friend bool operator <(li a, li b) 37 { 38 return a.l<b.l; 39 } 40 } mx; 41 set<li> st; 42 43 struct kdtree 44 { 45 node a[500010]; 46 void init() 47 { 48 49 a[0].son[0]=a[0].son[1]=0; 50 for (int i=0; i<5; i++) 51 { 52 a[0].mx[i]=-inf; 53 a[0].mi[i]=inf; 54 } 55 } 56 void update(int x) 57 { 58 int l=a[x].son[0],r=a[x].son[1]; 59 for (int i=0; i<k; i++) 60 { 61 a[x].mi[i]=min(a[x].nw.d[i],min(a[l].mi[i],a[r].mi[i])); 62 a[x].mx[i]=max(a[x].nw.d[i],max(a[l].mx[i],a[r].mx[i])); 63 } 64 } 65 int build(int l,int r,int cur) 66 { 67 if (l>r) return 0; 68 int m=(l+r)>>1; 69 key=cur; nth_element(a+l,a+m,a+r+1); 70 a[m].son[0]=build(l,m-1,(cur+1)%k); 71 a[m].son[1]=build(m+1,r,(cur+1)%k); 72 update(m); 73 return m; 74 } 75 int getmi(int x) 76 { 77 int s=0; 78 for (int i=0; i<k; i++) 79 s+=sqr(max(po.d[i]-a[x].mx[i],0)+max(a[x].mi[i]-po.d[i],0)); 80 return s; 81 } 82 void ask(int q) 83 { 84 if (!q) return; 85 int tmp=dis(a[q].nw,po); 86 st.insert((li){a[q].nw,tmp}); 87 if (st.size()>m) 88 { 89 set<li>::iterator it=st.end(); it--; 90 st.erase(it); 91 } 92 mxd=(*st.rbegin()).l; 93 int l=a[q].son[0],r=a[q].son[1],dl=2147483647,dr=2147483647; 94 if (l) dl=getmi(l); 95 if (r) dr=getmi(r); 96 if (dl<dr) 97 { 98 if (dl<mxd||st.size()<m) ask(l); 99 if (dr<mxd||st.size()<m) ask(r); 100 } 101 else { 102 if (dr<mxd||st.size()<m) ask(r); 103 if (dl<mxd||st.size()<m) ask(l); 104 } 105 } 106 } kd; 107 108 int main() 109 { 110 while (scanf("%d%d",&n,&k)!=EOF) 111 { 112 kd.init(); 113 for (int i=1; i<=n; i++) 114 for (int j=0; j<k; j++) 115 scanf("%d",&kd.a[i].nw.d[j]); 116 root=kd.build(1,n,0); 117 scanf("%d",&q); 118 while (q--) 119 { 120 for (int i=0; i<k; i++) 121 scanf("%d",&po.d[i]); 122 scanf("%d",&m); 123 st.clear(); 124 kd.ask(root); 125 printf("the closest %d points are:\n",m); 126 for (set<li>::iterator it=st.begin(); it!=st.end(); it++) 127 { 128 point ans=(*it).a; 129 for (int i=0; i<k; i++) 130 { 131 printf("%d",ans.d[i]); 132 if (i!=k-1) printf(" "); else puts(""); 133 } 134 } 135 } 136 } 137 }