模板:KD-Tree
KD-Tree(k-d树)用来解决多维空间中的问题,其实就是加了剪枝的暴力(逃
一般cdq分治能做的它大多也能做,而且……既然是优化暴力,那就学习一下了
对于若干个n维点,我们轮流按每一维把点集从中位数处一分为二,建立一棵二叉树;每个结点再记录它子树中所有点在每一维上的最小值和最大值(即包围盒),方便我们搜索时剪枝
它的插入比较麻烦,需要像替罪羊树一样在失衡时暴力重构,博主蒟蒻不会啦
KD-Tree能解决的问题:平面上点对的最小距离、最大距离、第k大距离(包括曼哈顿距离和欧氏距离)
当然这是我的理解,可能会有偏差
KD-Tree基本模板:
struct node{ int d[2],l,r,mx[2],mn[2],id; friend bool operator < (node a,node b){ return a.d[now]<b.d[now]; } }ask,tr[MAXN]; struct KD_TREE{ void update(int x){ int l=tr[x].l,r=tr[x].r; for(int i=0;i<=1;i++){ tr[x].mn[i]=tr[x].mx[i]=tr[x].d[i]; if(l!=0){ tr[x].mn[i]=min(tr[x].mn[i],tr[l].mn[i]); tr[x].mx[i]=max(tr[x].mx[i],tr[l].mx[i]); } if(r!=0){ tr[x].mn[i]=min(tr[x].mn[i],tr[r].mn[i]); tr[x].mx[i]=max(tr[x].mx[i],tr[r].mx[i]); } } } int dis(node a,node b){ int res=0; for(int i=0;i<=1;i++){ res+=power(a.d[i]-b.d[i]); } return res; } int get_dis(node a){ int res=0; for(int i=0;i<=1;i++){ res+=max(power(a.mx[i]-ask.d[i]),power(a.mn[i]-ask.d[i])); } return res; } void build(int &rt,int l,int r,int d){ int mid=(l+r)>>1; now=d; nth_element(tr+l,tr+mid,tr+r+1); if(l<mid) build(tr[mid].l,l,mid-1,d^1); if(r>mid) build(tr[mid].r,mid+1,r,d^1); update(mid); rt=mid; } void query(int x){ if(!x) return ; int sum_l=inf,sum_r=inf,dist=dis(tr[x],ask); if(tr[x].l) sum_l=get_dis(tr[tr[x].l]); if(tr[x].r) sum_r=get_dis(tr[tr[x].r]); if(dist>-q.top()){ q.pop(); q.push(-dist); } if(sum_l>sum_r){ if(sum_l>=-q.top()) query(tr[x].l); if(sum_r>=-q.top()) query(tr[x].r); }else{ if(sum_r>=-q.top()) query(tr[x].r); if(sum_l>=-q.top()) query(tr[x].l); } } }KD_Tree;
其中dis(两点间的距离)和get_dis(查询点到某子树包围盒的距离估价,用于剪枝)随题目而定,其他部分基本不变
例题:Hide and Seek
题目大意:给出平面内n个点的坐标,从中选出一个点,使它到其余各点的最远曼哈顿距离与最近曼哈顿距离之差最小,输出这个最小差值
思路:对每个点分别在KD-Tree上查询最近点和最远点(两个估价函数分别是到包围盒距离的下界和上界),取差值的最小值。挺水的,上代码:
// Hide and Seek: choose one of n points so that
//   (Manhattan distance to the farthest other point) - (distance to the nearest other point)
// is minimal, and print that minimum. Every point queries a 2-D KD-Tree twice.
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define MAXN 500005
#define inf 0x7fffffff
using namespace std;
// Reads an integer ('-' before the digits makes it negative), skipping other characters.
// Stops skipping at EOF so truncated input cannot loop forever.
int read(){
    int x=0,f=1,ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+(ch-'0');ch=getchar();}
    return x*f;
}
int n,now,root,ans=inf;
struct node{
    int d[2],l,r,mx[2],mn[2];  // d: point; l/r: children (0 = none); mn/mx: subtree bounding box
    // Orders points along the current split axis `now` (used by nth_element in build).
    friend bool operator < (const node &a,const node &b){
        return a.d[now]<b.d[now];
    }
    // Manhattan distance between two points (found via ADL).
    friend int dis(const node &a,const node &b){
        return std::abs(a.d[0]-b.d[0])+std::abs(a.d[1]-b.d[1]);
    }
}tr[MAXN];
struct KD_TREE{
    node p[MAXN],t;  // p: tree nodes (after build p[i] is the same point as tr[i]); t: query point
    int ans;         // best value found by the running query
    int self;        // index in p[] of the query point; query_min skips it (0 = skip nothing)
private:
    // Extend x's bounding box (already set to its own point by build) with its children's boxes.
    void update(int x){
        int l=p[x].l,r=p[x].r;
        for(int i=0;i<=1;i++){
            if(l!=0){
                p[x].mn[i]=min(p[x].mn[i],p[l].mn[i]);
                p[x].mx[i]=max(p[x].mx[i],p[l].mx[i]);
            }
            if(r!=0){
                p[x].mn[i]=min(p[x].mn[i],p[r].mn[i]);
                p[x].mx[i]=max(p[x].mx[i],p[r].mx[i]);
            }
        }
    }
    // Lower bound on the Manhattan distance from t to any point inside x's box.
    int get_min(const node &x){
        int sum=0;
        for(int i=0;i<=1;i++){
            sum+=max(x.mn[i]-t.d[i],0);
            sum+=max(t.d[i]-x.mx[i],0);
        }
        return sum;
    }
    // Upper bound on the Manhattan distance from t to any point inside x's box.
    int get_max(const node &x){
        int sum=0;
        for(int i=0;i<=1;i++){
            sum+=max(std::abs(x.mn[i]-t.d[i]),std::abs(x.mx[i]-t.d[i]));
        }
        return sum;
    }
public:
    // Build over tr[l..r] splitting on axis d (alternating per level); rt receives the root.
    // tr[mid] is never moved again by deeper calls, so p[i] ends up equal to tr[i].
    void build(int &rt,int l,int r,int d){
        int mid=(l+r)>>1;
        now=d;
        nth_element(tr+l,tr+mid,tr+r+1);
        p[mid]=tr[mid];
        for(int i=0;i<=1;i++) p[mid].mx[i]=p[mid].mn[i]=p[mid].d[i];
        if(l<mid) build(p[mid].l,l,mid-1,d^1);
        if(r>mid) build(p[mid].r,mid+1,r,d^1);
        update(mid);
        rt=mid;
    }
    // ans = min(ans, distance from t to every point in subtree k except p[self]).
    // The query point is excluded by index, not by "distance != 0", so another
    // point at the same location correctly yields a nearest distance of 0.
    void query_min(int k){
        if(k!=self) ans=min(ans,dis(p[k],t));
        int sum_l=inf,sum_r=inf;
        if(p[k].l) sum_l=get_min(p[p[k].l]);
        if(p[k].r) sum_r=get_min(p[p[k].r]);
        // Visit the child with the smaller lower bound first; prune if it cannot improve ans.
        if(sum_l>sum_r){
            if(sum_r<ans) query_min(p[k].r);
            if(sum_l<ans) query_min(p[k].l);
        }else{
            if(sum_l<ans) query_min(p[k].l);
            if(sum_r<ans) query_min(p[k].r);
        }
    }
    // ans = max(ans, distance from t to every point in subtree k); self contributes 0.
    void query_max(int k){
        ans=max(ans,dis(p[k],t));
        int sum_l=-inf,sum_r=-inf;
        if(p[k].l) sum_l=get_max(p[p[k].l]);
        if(p[k].r) sum_r=get_max(p[p[k].r]);
        // Visit the child with the larger upper bound first; prune if it cannot improve ans.
        if(sum_l>sum_r){
            if(sum_l>ans) query_max(p[k].l);
            if(sum_r>ans) query_max(p[k].r);
        }else{
            if(sum_r>ans) query_max(p[k].r);
            if(sum_l>ans) query_max(p[k].l);
        }
    }
}KD_Tree;
int main(){
    n=read();
    for(int i=1;i<=n;i++){
        tr[i].d[0]=read();
        tr[i].d[1]=read();
    }
    // With fewer than two points there is no "other" point, and build() would run on
    // an empty range. NOTE(review): printing 0 is an assumption; the problem presumably
    // guarantees n >= 2.
    if(n<2){
        puts("0");
        return 0;
    }
    KD_Tree.build(root,1,n,0);
    for(int i=1;i<=n;i++){
        KD_Tree.t=tr[i];
        KD_Tree.self=i;  // tr[i] and p[i] hold the same point after build
        KD_Tree.ans=inf;
        KD_Tree.query_min(root);
        int minn=KD_Tree.ans;
        KD_Tree.ans=-inf;
        KD_Tree.query_max(root);
        int maxx=KD_Tree.ans;
        ans=min(ans,maxx-minn);
    }
    printf("%d\n",ans);
    return 0;
}
例题:JZPFAR:
题目大意:给定平面上n个点的坐标以及m次询问,每次询问给出一个目标点和k,输出与目标点欧氏距离第k大的点的标号(距离相同时,标号较小的点视为距离更大)
跟上一题差不多,只是估价函数变了:这里用查询点到包围盒最远角的距离作为上界,
求第k大的话,维护一个大小为k的堆,堆顶是当前k个候选中最差的那个;每遇到比堆顶更优的点就pop堆顶再插入,最终堆顶就是答案
// JZPFAR: n points with ids 1..n, then m queries (px, py, k). For each query print
// the id of the point whose squared Euclidean distance to (px, py) is the k-th
// largest; among equal distances the smaller id counts as farther.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#define MAXN 500005
#define inf 0x7fffffff
#define int long long
using namespace std;
// Reads an integer, skipping leading non-digit characters ('-' makes it negative).
int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
    return x*f;
}
// Square of a coordinate difference (all arithmetic is long long via the macro above).
int power(int a){return a*a;}
int n,m,now,root,k;
// One kept candidate: squared distance to the query point and the point's id.
// Named `cand` rather than `data`: with `using namespace std`, C++17's std::data
// makes an unqualified `data` ambiguous.
struct cand{
    int dis,id;
    cand(int d,int i):dis(d),id(i){}
    // "a < b" means a is the better (farther) candidate, so the max-heap's top is the
    // WORST kept candidate: smallest distance, and on ties the largest id.
    // (The original compared a.id<a.id, which is always false and lost the tie-break.)
    friend bool operator < (const cand &a,const cand &b){
        return a.dis==b.dis?a.id<b.id:a.dis>b.dis;
    }
};
priority_queue<cand> q;  // the k best candidates for the current query
struct node{
    int d[2],l,r,mx[2],mn[2],id;  // d: point; l/r: children (0 = none); mn/mx: subtree bounding box
    // Orders points along the current split axis `now` (used by nth_element in build).
    friend bool operator < (const node &a,const node &b){
        return a.d[now]<b.d[now];
    }
}ask,tr[MAXN],p[MAXN];  // ask: query point; p: input points; tr: tree nodes
struct KD_TREE{
private:
    // Recompute x's bounding box from its own point plus its children's boxes.
    void update(int x){
        int l=tr[x].l,r=tr[x].r;
        for(int i=0;i<=1;i++){
            tr[x].mn[i]=tr[x].mx[i]=tr[x].d[i];
            if(l!=0){
                tr[x].mn[i]=std::min(tr[x].mn[i],tr[l].mn[i]);
                tr[x].mx[i]=std::max(tr[x].mx[i],tr[l].mx[i]);
            }
            if(r!=0){
                tr[x].mn[i]=std::min(tr[x].mn[i],tr[r].mn[i]);
                tr[x].mx[i]=std::max(tr[x].mx[i],tr[r].mx[i]);
            }
        }
    }
    // Squared Euclidean distance between two points.
    int dis(const node &a,const node &b){
        int res=0;
        for(int i=0;i<=1;i++) res+=power(a.d[i]-b.d[i]);
        return res;
    }
    // Upper bound on the squared distance from `ask` to any point in subtree x.
    // A missing child (x == 0) returns -2, below the -1 placeholder distance, so it
    // is never descended into.
    int calc(int x){
        if(x==0) return -2;
        int res=0;
        for(int i=0;i<=1;i++){
            res+=std::max(power(tr[x].mx[i]-ask.d[i]),power(tr[x].mn[i]-ask.d[i]));
        }
        return res;
    }
    // True when (dist, id) is a better answer than the worst kept candidate.
    bool beats_top(int dist,int id){
        const cand &top=q.top();
        return dist>top.dis||(dist==top.dis&&id<top.id);
    }
public:
    // Build over p[l..r] into tr[], median along axis d, axes alternating; rt receives the root.
    void build(int &rt,int l,int r,int d){
        int mid=(l+r)>>1;
        now=d;
        nth_element(p+l,p+mid,p+r+1);
        tr[mid]=p[mid];
        if(l<mid) build(tr[mid].l,l,mid-1,d^1);
        if(r>mid) build(tr[mid].r,mid+1,r,d^1);
        update(mid);
        rt=mid;
    }
    // Offer every point of subtree x to the heap, visiting the child with the larger
    // bound first. ">=" keeps equal-bound subtrees because they may hold a tie with a smaller id.
    void query(int x){
        if(!x) return;
        int sum_l=calc(tr[x].l),sum_r=calc(tr[x].r),dist=dis(tr[x],ask);
        if(beats_top(dist,tr[x].id)){
            q.pop();
            q.push(cand(dist,tr[x].id));
        }
        if(sum_l>sum_r){
            if(sum_l>=q.top().dis) query(tr[x].l);
            if(sum_r>=q.top().dis) query(tr[x].r);
        }else{
            if(sum_r>=q.top().dis) query(tr[x].r);
            if(sum_l>=q.top().dis) query(tr[x].l);
        }
    }
}KD_Tree;
signed main(){
    n=read();
    for(int i=1;i<=n;i++){
        p[i].d[0]=read();
        p[i].d[1]=read();
        p[i].id=i;
    }
    KD_Tree.build(root,1,n,0);
    m=read();
    for(int qi=1;qi<=m;qi++){
        ask.d[0]=read();
        ask.d[1]=read();
        k=read();
        while(!q.empty()) q.pop();
        // k placeholders at distance -1 (below every real squared distance) so the
        // heap always holds exactly k entries.
        for(int j=1;j<=k;j++) q.push(cand(-1,0));
        KD_Tree.query(root);
        printf("%lld\n",q.top().id);
    }
    return 0;
}
进阶:K远点对
题目大意:已知平面内 N 个点的坐标,求欧氏距离下的第 K 远点对。
和上一题一样,只不过这一次我们要以每个点为询问点各query一遍,
其中维护一个大小为2k的堆(因为每个点对会被它的两个端点各算一次),最终堆顶就是答案(输出的是距离的平方)
// K farthest pair: given N points, print the K-th largest pairwise squared
// Euclidean distance. Every point queries a 2-D KD-Tree. An unordered pair {a,b}
// is seen once from a and once from b, so the 2K largest values are kept and the
// smallest of those (the heap top) is the answer.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<functional>

typedef long long ll;

const int kMaxPoints=100005;
const ll kMissingChildBound=0x7fffffff;  // bound reported for an absent child; that child is never entered
const ll kEmptySlot=-0x7fffffffLL;       // placeholder "distance" below every real one

struct Point{
    ll c[2];         // coordinates
    ll lo[2],hi[2];  // bounding box of the subtree rooted here
    int lc,rc;       // children indices in pts[] (0 = none)
};

// Compares two points along one fixed axis; used to place the median in construct().
struct ByAxis{
    int axis;
    bool operator()(const Point &a,const Point &b) const{
        return a.c[axis]<b.c[axis];
    }
};

Point pts[kMaxPoints];
Point probe;  // current query point
// Min-heap of the 2K largest distances seen so far; top() is the smallest of them.
std::priority_queue<ll,std::vector<ll>,std::greater<ll> > best;

// Reads an integer, skipping leading non-digit characters ('-' makes it negative).
ll readInt(){
    ll val=0,sign=1;
    char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-') sign=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        val=val*10+(c-'0');
        c=getchar();
    }
    return val*sign;
}

ll sq(ll v){ return v*v; }

// Squared Euclidean distance between two points.
ll dist2(const Point &a,const Point &b){
    ll s=0;
    for(int ax=0;ax<2;++ax) s+=sq(a.c[ax]-b.c[ax]);
    return s;
}

// Upper bound on the squared distance from probe to anything inside box's bounding box.
ll farthestBound(const Point &box){
    ll s=0;
    for(int ax=0;ax<2;++ax) s+=std::max(sq(box.hi[ax]-probe.c[ax]),sq(box.lo[ax]-probe.c[ax]));
    return s;
}

// Recompute node x's bounding box from its own point and its children's boxes.
void pullUp(int x){
    Point &cur=pts[x];
    for(int ax=0;ax<2;++ax){
        cur.lo[ax]=cur.hi[ax]=cur.c[ax];
        if(cur.lc){
            cur.lo[ax]=std::min(cur.lo[ax],pts[cur.lc].lo[ax]);
            cur.hi[ax]=std::max(cur.hi[ax],pts[cur.lc].hi[ax]);
        }
        if(cur.rc){
            cur.lo[ax]=std::min(cur.lo[ax],pts[cur.rc].lo[ax]);
            cur.hi[ax]=std::max(cur.hi[ax],pts[cur.rc].hi[ax]);
        }
    }
}

// Build a balanced tree over pts[from..to] (median along `axis`, axes alternate); returns its root.
int construct(int from,int to,int axis){
    const int mid=(from+to)>>1;
    ByAxis cmp;
    cmp.axis=axis;
    std::nth_element(pts+from,pts+mid,pts+to+1,cmp);
    if(from<mid) pts[mid].lc=construct(from,mid-1,axis^1);
    if(to>mid) pts[mid].rc=construct(mid+1,to,axis^1);
    pullUp(mid);
    return mid;
}

// Offer every point of subtree x to the heap, entering the more promising child first
// and skipping children whose bound is below the current smallest kept distance.
void explore(int x){
    if(x==0) return;
    const Point &cur=pts[x];
    const ll here=dist2(cur,probe);
    const ll boundL=cur.lc?farthestBound(pts[cur.lc]):kMissingChildBound;
    const ll boundR=cur.rc?farthestBound(pts[cur.rc]):kMissingChildBound;
    if(here>best.top()){
        best.pop();
        best.push(here);
    }
    const bool leftFirst=boundL>boundR;
    const int first=leftFirst?cur.lc:cur.rc;
    const int second=leftFirst?cur.rc:cur.lc;
    const ll firstBound=leftFirst?boundL:boundR;
    const ll secondBound=leftFirst?boundR:boundL;
    if(firstBound>=best.top()) explore(first);
    if(secondBound>=best.top()) explore(second);
}

int main(){
    const ll n=readInt();
    const ll k=readInt();
    for(ll i=1;i<=n;++i){
        pts[i].c[0]=readInt();
        pts[i].c[1]=readInt();
    }
    const int root=construct(1,static_cast<int>(n),0);
    for(ll i=0;i<2*k;++i) best.push(kEmptySlot);
    for(ll i=1;i<=n;++i){
        probe=pts[i];
        explore(root);
    }
    printf("%lld\n",best.top());
    return 0;
}