C19【模板】KD 树 交替建树
视频链接:241 kd 树 交替建树_哔哩哔哩_bilibili
// 交替建树 970ms #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #define lc t[p].l #define rc t[p].r using namespace std; const int N=200010; double ans=2e18; int n,K,root,cur; //K维度,root根,cur当前节点 struct KD{ //KD树节点信息 int l,r; //左右孩子 double v[2]; //点的坐标值 double L[2],U[2]; //子树区域的坐标范围 bool operator<(const KD &b)const{return v[K]<b.v[K];} }t[N]; void pushup(int p){ //更新p子树区域的坐标范围 for(int i=0;i<2;i++){ t[p].L[i]=t[p].U[i]=t[p].v[i]; if(lc) t[p].L[i]=min(t[p].L[i],t[lc].L[i]), t[p].U[i]=max(t[p].U[i],t[lc].U[i]); if(rc) t[p].L[i]=min(t[p].L[i],t[rc].L[i]), t[p].U[i]=max(t[p].U[i],t[rc].U[i]); } } int build(int l,int r,int k){ //交替建树 if(l>r) return 0; int m=(l+r)>>1; K=k; nth_element(t+l,t+m,t+r+1); //中位数 t[m].l=build(l,m-1,k^1); t[m].r=build(m+1,r,k^1); pushup(m); return m; } double sq(double x){return x*x;} double dis(int p){ //当前点到p点的距离 double s=0; for(int i=0;i<2;i++) s+=sq(t[cur].v[i]-t[p].v[i]); return s; } double dis2(int p){ //当前点到p子树区域的最小距离 if(!p) return 2e18; double s=0; for(int i=0;i<2;++i) s+=sq(max(t[cur].v[i]-t[p].U[i],0.0))+ sq(max(t[p].L[i]-t[cur].v[i],0.0)); return s; } void query(int p){ //查询当前点的最小距离 if(!p) return; if(p!=cur) ans=min(ans,dis(p)); double dl=dis2(lc),dr=dis2(rc); if(dl<dr){ if(dl<ans) query(lc); if(dr<ans) query(rc); } else{ if(dr<ans) query(rc); if(dl<ans) query(lc); } } int main(){ scanf("%d",&n); for(int i=1; i<=n; i++) scanf("%lf%lf",&t[i].v[0],&t[i].v[1]); root=build(1,n,0); for(cur=1; cur<=n; cur++) query(root); printf("%.4lf\n",sqrt(ans)); }
// 方差建树 1.1s #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #define lc t[p].l #define rc t[p].r using namespace std; const int N=200010; double ans=2e18; int n,K,root,cur; //K维度,root根,cur当前节点 struct KD{ //KD树节点信息 int l,r; //左右孩子 double v[2]; //点的坐标值 double L[2],U[2]; //子树区域的坐标范围 bool operator<(const KD &b)const{return v[K]<b.v[K];} }t[N]; double sq(double x){return x*x;} void pushup(int p){ //更新p子树区域的坐标范围 for(int i=0;i<2;i++){ t[p].L[i]=t[p].U[i]=t[p].v[i]; if(lc) t[p].L[i]=min(t[p].L[i],t[lc].L[i]), t[p].U[i]=max(t[p].U[i],t[lc].U[i]); if(rc) t[p].L[i]=min(t[p].L[i],t[rc].L[i]), t[p].U[i]=max(t[p].U[i],t[rc].U[i]); } } int build(int l,int r){ //方差建树 if(l>r) return 0; int m=(l+r)>>1; double av[2]={0},va[2]={0}; //average, variance for(int i=l; i<=r; i++) for(int j=0; j<2; j++) av[j]+=t[i].v[j]; av[0]/=r-l+1; av[1]/=r-l+1; //平均值 for(int i=l; i<=r; i++) for(int j=0; j<2; j++) va[j]+=sq(t[i].v[j]-av[j]); //方差 K=va[0]<va[1]; //方差大,离散度大,分割最优 nth_element(t+l,t+m,t+r+1); t[m].l=build(l,m-1); t[m].r=build(m+1,r); pushup(m); return m; } double dis(int p){ //当前点到p点的距离 double s=0; for(int i=0;i<2;i++) s+=sq(t[cur].v[i]-t[p].v[i]); return s; } double dis2(int p){ //当前点到p子树区域的最小距离 if(!p) return 2e18; double s=0; for(int i=0;i<2;++i) s+=sq(max(t[cur].v[i]-t[p].U[i],0.0))+ sq(max(t[p].L[i]-t[cur].v[i],0.0)); return s; } void query(int p){ //查询当前点的最小距离 if(!p) return; if(p!=cur) ans=min(ans,dis(p)); double dl=dis2(lc),dr=dis2(rc); if(dl<dr){ if(dl<ans) query(lc); if(dr<ans) query(rc); } else{ if(dr<ans) query(rc); if(dl<ans) query(lc); } } int main(){ scanf("%d",&n); for(int i=1; i<=n; i++) scanf("%lf%lf",&t[i].v[0],&t[i].v[1]); root=build(1,n); for(cur=1; cur<=n; cur++) query(root); printf("%.4lf\n",sqrt(ans)); }
练习:
// TLE#13,正解旋转卡壳 #include <algorithm> #include <cstring> #include <iostream> #include <cmath> #include <queue> #define lc t[p].l #define rc t[p].r using namespace std; typedef long long LL; const int N=100010; int n,k,K,root,cur; //K维度,root根,cur当前节点 struct KD{ //KD树节点信息 int l,r; //左右孩子 LL v[2]; //点的坐标值 LL L[2],U[2]; //子树区域的坐标范围 bool operator<(const KD &b)const{return v[K]<b.v[K];} }t[N]; priority_queue<LL,vector<LL>,greater<LL> >q; //小根堆 void pushup(int p){ //更新p子树区域的坐标范围 for(int i=0;i<2;i++){ t[p].L[i]=t[p].U[i]=t[p].v[i]; if(lc) t[p].L[i]=min(t[p].L[i],t[lc].L[i]), t[p].U[i]=max(t[p].U[i],t[lc].U[i]); if(rc) t[p].L[i]=min(t[p].L[i],t[rc].L[i]), t[p].U[i]=max(t[p].U[i],t[rc].U[i]); } } int build(int l,int r,int k){ //交替建树 if(l>r) return 0; int m=(l+r)>>1; K=k; nth_element(t+l,t+m,t+r+1); //中位数 t[m].l=build(l,m-1,k^1); t[m].r=build(m+1,r,k^1); pushup(m); return m; } LL sq(LL x){return x*x;} LL dis(int p){ //当前点到p点的距离 LL s=0; for(int i=0;i<2;i++) s+=sq(t[cur].v[i]-t[p].v[i]); return s; } LL dis2(int p){ //当前点到p子树区域的最大距离 if(!p) return 0; LL s=0; for(int i=0;i<2;++i) s+=sq(max(t[cur].v[i]-t[p].L[i],1ll*0))+ sq(max(t[p].U[i]-t[cur].v[i],1ll*0)); return s; } void query(int p){ //查询第K远点对的距离 if(!p) return; if(p==cur) return; LL d=dis(p); if(d>q.top()) q.pop(),q.push(d); LL dl=dis2(lc), dr=dis2(rc); if(dl>dr){ if(dl>q.top()) query(lc); if(dr>q.top()) query(rc); } else{ if(dr>q.top()) query(rc); if(dl>q.top()) query(lc); } } int main(){ scanf("%d%d",&n,&k); k*=2; //因为两点距离算了两遍 for(int i=1; i<=k; i++) q.push(0); for(int i=1; i<=n; i++) scanf("%lld%lld",&t[i].v[0],&t[i].v[1]); root=build(1,n,0); for(cur=1; cur<=n; cur++) query(root); printf("%lld\n",q.top()); }