模板:KD-Tree

KD-Tree,用来解决多维空间中的问题,其实就是优化暴力(逃

一般cdq能做的它都能做,而且。。。既然是优化暴力,那就学习一下了

对于若干个n维点,我们按维度轮流进行分割,建立一棵二叉树,方便搜索时剪枝

插入操作比较麻烦,需要像替罪羊树一样暴力重构,博主蒟蒻还不会(所以下面只涉及静态建树)

KD-Tree能解决的问题:平面上点对的最小距离、最大距离、第k大距离(包括曼哈顿距离和欧氏距离)

当然这是我的理解,可能会有偏差

KD-Tree基本模板:

// One KD-tree node: a 2-D point, child indices l/r into tr[] (0 = no child),
// the per-axis max/min (mx/mn) of its subtree's bounding box, and a label id.
struct node{
	int d[2];
	int l,r;
	int mx[2],mn[2];
	int id;
	// Order points by their coordinate on the current split axis `now`
	// (declared elsewhere); nth_element in build() relies on this.
	friend bool operator < (const node &lhs,const node &rhs){
		return lhs.d[now]<rhs.d[now];
	}
}ask,tr[MAXN];
// Template KD-tree built in place over tr[], shown here in its
// "k farthest points" form with squared Euclidean distance.
// NOTE(review): relies on names declared outside this snippet: tr[], ask
// (query point), now (split axis), q (presumably priority_queue<int> holding
// *negated* distances, see q.push(-dist)), inf and power().
struct KD_TREE{
	// Recompute node x's bounding box from its own point and its children's
	// boxes; a child index of 0 means "no child".
	void update(int x){
		int l=tr[x].l,r=tr[x].r;
		for(int i=0;i<=1;i++){
			tr[x].mn[i]=tr[x].mx[i]=tr[x].d[i];
			if(l!=0){
				tr[x].mn[i]=min(tr[x].mn[i],tr[l].mn[i]);
				tr[x].mx[i]=max(tr[x].mx[i],tr[l].mx[i]);
			}
			if(r!=0){
				tr[x].mn[i]=min(tr[x].mn[i],tr[r].mn[i]);
				tr[x].mx[i]=max(tr[x].mx[i],tr[r].mx[i]);
			}
		}
	}
	// Squared Euclidean distance between two points (problem-specific).
	int dis(node a,node b){
		int res=0;
		for(int i=0;i<=1;i++){
			res+=power(a.d[i]-b.d[i]);
		}
		return res;
	}
	// Upper bound on the squared distance from `ask` to any point inside a's
	// bounding box: on each axis take the farther of the two box faces.
	int get_dis(node a){
		int res=0;
		for(int i=0;i<=1;i++){
			res+=max(power(a.mx[i]-ask.d[i]),power(a.mn[i]-ask.d[i]));
		}
		return res;
	}
	// Build a balanced subtree over tr[l..r] splitting on axis d (alternating
	// per level via d^1). The median tr[mid] becomes the subtree root, its
	// index is stored into rt.
	void build(int &rt,int l,int r,int d){
		int mid=(l+r)>>1;
		now=d;
		nth_element(tr+l,tr+mid,tr+r+1);
		if(l<mid) build(tr[mid].l,l,mid-1,d^1);
		if(r>mid) build(tr[mid].r,mid+1,r,d^1);
		update(mid);
		rt=mid;
	}
	// Visit node x: offer its distance to the heap, then descend into children
	// whose upper bound can still reach the current k-th best (-q.top()).
	void query(int x){
		if(!x) return ;
		// Missing children keep bound inf; if visited, query(0) returns at once.
		int sum_l=inf,sum_r=inf,dist=dis(tr[x],ask);
		if(tr[x].l) sum_l=get_dis(tr[tr[x].l]);
		if(tr[x].r) sum_r=get_dis(tr[tr[x].r]);
		// Farther than the current k-th best: replace it (values stored negated).
		if(dist>-q.top()){
			q.pop();
			q.push(-dist);
		}
		// Explore the child with the larger bound first so the threshold rises
		// sooner; re-check the threshold before each descent.
		if(sum_l>sum_r){
			if(sum_l>=-q.top()) query(tr[x].l);
			if(sum_r>=-q.top()) query(tr[x].r);
		}else{
			if(sum_r>=-q.top()) query(tr[x].r);
			if(sum_l>=-q.top()) query(tr[x].l);
		}
	}
}KD_Tree;

其中dis和get_dis随题目而定(dis是两点间的距离,get_dis是查询点到子树包围盒内任意点距离的估价,用于剪枝),其他部分基本不变

 

例题:Hide and Seek

题目大意:给出平面内n个点的坐标,对每个点求它到其他点的最大曼哈顿距离与最小曼哈顿距离之差,输出所有点中这个差值的最小值

挺水的,上代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define MAXN 500005
#define inf 0x7fffffff
using namespace std;
// Read the next integer from stdin, skipping any non-digit characters.
// A '-' seen while skipping makes the number negative (matching the other
// read() helpers in this post). Returns 0 at end of input instead of
// spinning forever: getchar() keeps returning EOF, which is never a digit.
int read(){
	int x=0,f=1;
	int ch=getchar();	// int, not char: EOF must stay distinguishable
	while(ch<'0'||ch>'9'){
		if(ch==EOF) return 0;
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return x*f;
}
// Plain int helpers used throughout this program.
int max(int a,int b){
	if(a>b) return a;
	return b;
}
int min(int a,int b){
	if(a<b) return a;
	return b;
}
int abs(int a){
	if(a<0) return -a;
	return a;
}
// n: number of points; now: split axis read by node::operator< (set in build);
// root: index of the KD-tree root; ans: smallest (max - min) spread found so far.
int n,now,root,ans=inf;
// A 2-D point plus KD-tree bookkeeping: child indices l/r (0 = none) and the
// per-axis max/min of its subtree's bounding box.
struct node{
	int d[2];
	int l,r;
	int mx[2],mn[2];
	// Order by the coordinate on the current split axis `now` (for nth_element).
	friend bool operator < (const node &lhs,const node &rhs){
		return lhs.d[now]<rhs.d[now];
	}
	// Manhattan distance between two points.
	friend int dis(const node &lhs,const node &rhs){
		int total=0;
		for(int axis=0;axis<2;axis++)
			total+=abs(lhs.d[axis]-rhs.d[axis]);
		return total;
	}
}tr[MAXN];
// KD-tree stored in p[] (tree node x lives at p[x]) answering, for the query
// point t, the nearest (query_min) and farthest (query_max) Manhattan distance.
// Results accumulate in `ans`, which the caller must initialise beforehand.
struct KD_TREE{
	// p: tree storage indexed by node id; t: current query point.
	node p[MAXN],t;
	// Result of the last query_min/query_max.
	int ans;
	private:
		// Fold the children's bounding boxes into node x's box. Assumes
		// p[x].mn/mx were already set to x's own point (build() does this).
		void update(int x){
			int l=p[x].l,r=p[x].r;
			for(int i=0;i<=1;i++){
				if(l!=0){
					p[x].mn[i]=min(p[x].mn[i],p[l].mn[i]);
					p[x].mx[i]=max(p[x].mx[i],p[l].mx[i]);
				}
				if(r!=0){
					p[x].mn[i]=min(p[x].mn[i],p[r].mn[i]);
					p[x].mx[i]=max(p[x].mx[i],p[r].mx[i]);
				}
			}
		}
		// Lower bound on the Manhattan distance from t to any point in x's box:
		// per axis, the gap between t and the box (0 when t is inside the range).
		int get_min(node x){
			int sum=0;
			for(int i=0;i<=1;i++){
				sum+=max(x.mn[i]-t.d[i],0);
            	sum+=max(t.d[i]-x.mx[i],0); 
			}
			return sum;
		}
		// Upper bound on the Manhattan distance from t to any point in x's box:
		// per axis, the distance to the farther box face.
		int get_max(node x){
			int sum=0;
			for(int i=0;i<=1;i++){
				sum+=max(abs(x.mn[i]-t.d[i]),abs(x.mx[i]-t.d[i]));
			}
			return sum;
		}
	public:
		// Build a balanced subtree over tr[l..r] splitting on axis d (alternating
		// per level). The median is copied into p[mid] and its index stored in rt.
		void build(int &rt,int l,int r,int d){
			int mid=l+r>>1;
			now=d;
			nth_element(tr+l,tr+mid,tr+r+1);
			p[mid]=tr[mid];
			// Start the bounding box as just this node's own point.
			for(int i=0;i<=1;i++)
            	p[mid].mx[i]=p[mid].mn[i]=p[mid].d[i];
			if(l<mid) build(p[mid].l,l,mid-1,d^1);
			if(r>mid) build(p[mid].r,mid+1,r,d^1);
			update(mid);
			rt=mid;
		}
		// Nearest-neighbour search from node k. Distance 0 is skipped so the
		// query point does not count as its own nearest neighbour.
		// NOTE(review): this also skips genuinely coincident points; correct only
		// if the input never repeats a coordinate pair.
		void query_min(int k){
			int dist=dis(p[k],t);
			if(dist) ans=min(ans,dist);
			// Missing children keep bound inf, which never passes `< ans`.
			int sum_l=inf,sum_r=inf;
			if(p[k].l) sum_l=get_min(p[p[k].l]);
			if(p[k].r) sum_r=get_min(p[p[k].r]);
			// Visit the closer box first; prune boxes that cannot beat ans.
			if(sum_l>sum_r){
				if(sum_r<ans) query_min(p[k].r);
				if(sum_l<ans) query_min(p[k].l);
			}else{
				if(sum_l<ans) query_min(p[k].l);
				if(sum_r<ans) query_min(p[k].r);
			}
		}
		// Farthest-point search from node k; prunes boxes whose upper bound
		// cannot exceed ans. The query point itself (distance 0) is included,
		// which cannot change a maximum that is already >= 0.
		void query_max(int k){
			ans=max(ans,dis(p[k],t));
			// Missing children keep bound -inf, which never passes `> ans`.
			int sum_l=-inf,sum_r=-inf;
			if(p[k].l) sum_l=get_max(p[p[k].l]);
			if(p[k].r) sum_r=get_max(p[p[k].r]);
			// Visit the farther box first.
			if(sum_l>sum_r){
				if(sum_l>ans) query_max(p[k].l);
				if(sum_r>ans) query_max(p[k].r);
			}else{
				if(sum_r>ans) query_max(p[k].r);
				if(sum_l>ans) query_max(p[k].l);
			}
		}
}KD_Tree;
// Reads n points and builds the KD-tree. For every point it computes
// (farthest Manhattan distance) - (nearest non-zero Manhattan distance),
// then prints the smallest such spread.
int main(){
	n=read();
	for(int idx=1;idx<=n;++idx){
		tr[idx].d[0]=read();
		tr[idx].d[1]=read();
	}
	KD_Tree.build(root,1,n,0);
	for(int idx=1;idx<=n;++idx){
		KD_Tree.t=tr[idx];
		// Nearest neighbour (query_min ignores distance 0).
		KD_Tree.ans=inf;
		KD_Tree.query_min(root);
		const int nearest=KD_Tree.ans;
		// Farthest point.
		KD_Tree.ans=-inf;
		KD_Tree.query_max(root);
		const int farthest=KD_Tree.ans;
		const int spread=farthest-nearest;
		if(spread<ans) ans=spread;
	}
	printf("%d\n",ans);
	return 0;
}

例题:JZPFAR:

题目大意:给定平面上n个点的坐标以及m次询问,每次询问给出一个目标点和k,输出与目标点欧氏距离第k大的点的编号(距离相同时取编号较小者)

跟上一个差不多,只是估价函数变了(改为欧氏距离平方的上界)

求第k大时,维护一个大小为k的堆,堆顶是当前k个候选中最差的一个;每遇到更优的点就弹出堆顶再插入,最终堆顶就是答案

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#define MAXN 500005
#define inf 0x7fffffff
#define int long long
using namespace std;
// Read the next (possibly negative) integer from stdin, skipping any
// non-digit characters; a '-' seen while skipping flips the sign.
// Returns 0 at end of input instead of spinning forever: getchar() keeps
// returning EOF, which is never a digit.
int read(){
	int x=0,f=1;
	int ch=getchar();	// not char: EOF must stay distinguishable
	while(ch<'0'||ch>'9'){
		if(ch==EOF) return 0;
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return x*f;
}
// Integer helpers for this program (`int` is #defined to long long above).
int max(int a,int b){
	if(a>b) return a;
	return b;
}
int min(int a,int b){
	if(a<b) return a;
	return b;
}
int abs(int a){
	if(a<0) return -a;
	return a;
}
// Square of a value; used for squared Euclidean distances.
int power(int a){
	const int squared=a*a;
	return squared;
}
// n: number of points; m: number of queries; now: split axis read by
// node::operator< (set in build); root: KD-tree root index; k: rank requested
// by the current query.
int n,m,now,root,k;
// Heap entry: squared distance `dis` and point label `id`.
// operator< is arranged so that priority_queue::top() is the *worst* of the k
// candidates kept so far: the smallest distance, and among equal distances the
// largest id (a smaller id is preferred on ties). The tie-break must compare
// a.id with b.id; comparing a.id with itself would always be false and leave
// equal-distance entries unordered.
struct data{
	int dis,id;
	friend bool operator < (const data &a,const data &b){
		return a.dis==b.dis?a.id<b.id:a.dis>b.dis;
	}
};
// Written as ::data: with `using namespace std;` under C++17 an unqualified
// `data` is ambiguous with std::data.
std::priority_queue< ::data > q;
// KD-tree node: point d[], child indices l/r (0 = none), subtree bounding box
// mx[]/mn[], and the point's 1-based label id.
struct node{
	int d[2];
	int l,r;
	int mx[2],mn[2];
	int id;
	// Order by the coordinate on the current split axis `now`; used by
	// nth_element while building.
	friend bool operator < (const node &lhs,const node &rhs){
		return lhs.d[now]<rhs.d[now];
	}
}ask,tr[MAXN],p[MAXN];
struct KD_TREE{
	private:
		void update(int x){
			int l=tr[x].l,r=tr[x].r;
			for(int i=0;i<=1;i++){
				tr[x].mn[i]=tr[x].mx[i]=tr[x].d[i];
				if(l!=0){
					tr[x].mn[i]=min(tr[x].mn[i],tr[l].mn[i]);
					tr[x].mx[i]=max(tr[x].mx[i],tr[l].mx[i]);
				}
				if(r!=0){
					tr[x].mn[i]=min(tr[x].mn[i],tr[r].mn[i]);
					tr[x].mx[i]=max(tr[x].mx[i],tr[r].mx[i]);
				}
			}
		}
		int dis(node a,node b){
			int res=0;
			for(int i=0;i<=1;i++){
				res+=power(a.d[i]-b.d[i]);
			}
			return res;
		}
		int calc(int x){
			if(x==0) return -2;
			int res=0;
			for(int i=0;i<=1;i++){
				res+=max(power(tr[x].mx[i]-ask.d[i]),power(tr[x].mn[i]-ask.d[i]));
			}
			return res;
		}
	public:
		void build(int &rt,int l,int r,int d){
			int mid=(l+r)>>1;
			now=d;
			nth_element(p+l,p+mid,p+r+1);
			tr[mid]=p[mid];
			if(l<mid) build(tr[mid].l,l,mid-1,d^1);
			if(r>mid) build(tr[mid].r,mid+1,r,d^1);
			update(mid);
			rt=mid;
		}
		void query(int x){
			//cout<<x<<endl;
			if(!x) return ;
			int sum_l=calc(tr[x].l),sum_r=calc(tr[x].r),dist=dis(tr[x],ask);
			//cout<<sum_l<<' '<<sum_r<<' '<<dist<<endl;
			if(dist>q.top().dis||(dist==q.top().dis&&tr[x].id<q.top().id)){
				q.pop();
				//cout<<dist<<' '<<tr[x].id<<endl;
				q.push((data){dist,tr[x].id});
			}
			if(sum_l>sum_r){
				if(sum_l>=q.top().dis) query(tr[x].l);
				if(sum_r>=q.top().dis) query(tr[x].r);
			}else{
				if(sum_r>=q.top().dis) query(tr[x].r);
				if(sum_l>=q.top().dis) query(tr[x].l);
			}
		}
}KD_Tree;
signed main(){
	n=read();
	for(int i=1;i<=n;i++){
		p[i].d[0]=read();
		p[i].d[1]=read();
		p[i].id=i;
	}
	KD_Tree.build(root,1,n,0);
	//cout<<root<<endl;
	m=read();
	for(int i=1;i<=m;i++){
		ask.d[0]=read();
		ask.d[1]=read();
		k=read();
		while(!q.empty()) q.pop();
		for(int i=1;i<=k;i++)
			q.push((data){-1,0});
		//	q.push(make_pair(-1,0));
		KD_Tree.query(root);
		//cout<<q.top().first<<endl;
		printf("%lld\n",q.top().id);
	}
	return 0;
}

进阶:K远点对

题目大意:已知平面内 N 个点的坐标,求欧氏距离下的第 K 远点对。

和上一个一样,只不过这一次我们要对每个点query一遍,

其中维护一个大小为2k的堆(每个无序点对会从两个端点各被计算一次,所以要开2k),最终堆顶就是答案

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#define MAXN 100005
#define inf 0x7fffffff
#define int long long
using namespace std;
// Read the next (possibly negative) integer from stdin, skipping any
// non-digit characters; a '-' seen while skipping flips the sign.
// Returns 0 at end of input instead of spinning forever: getchar() keeps
// returning EOF, which is never a digit.
int read(){
	int x=0,f=1;
	int ch=getchar();	// not char: EOF must stay distinguishable
	while(ch<'0'||ch>'9'){
		if(ch==EOF) return 0;
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return x*f;
}
// Integer helpers for this program (`int` is #defined to long long above).
int max(int a,int b){
	if(a>b) return a;
	return b;
}
int min(int a,int b){
	if(a<b) return a;
	return b;
}
int abs(int a){
	if(a<0) return -a;
	return a;
}
// Square of a value; used for squared Euclidean distances.
int power(int a){
	const int squared=a*a;
	return squared;
}
// n: number of points; k: rank of the requested farthest pair; now: split axis
// read by node::operator< (set in build); root: KD-tree root index.
// NOTE(review): m is not used anywhere in this program.
int n,m,now,root,k;
// Heap entry ordered by distance (a larger dis has higher priority).
// NOTE(review): this type is not referenced anywhere else in this program.
struct data{
	int dis;
	friend bool operator < (const data &lhs,const data &rhs){
		return lhs.dis<rhs.dis;
	}
};
// Holds *negated* squared distances, so -q.top() is the smallest distance kept,
// i.e. the current 2k-th largest once the heap is full.
priority_queue<int> q;
// KD-tree node: point d[], child indices l/r (0 = none), subtree bounding box
// mx[]/mn[], and the point's 1-based label id.
struct node{
	int d[2];
	int l,r;
	int mx[2],mn[2];
	int id;
	// Order by the coordinate on the current split axis `now`; used by
	// nth_element while building.
	friend bool operator < (const node &lhs,const node &rhs){
		return lhs.d[now]<rhs.d[now];
	}
}ask,tr[MAXN];
// KD-tree built in place over tr[] that collects the largest squared Euclidean
// distances from `ask` into the global heap q (which stores negated values).
struct KD_TREE{
	// Recompute node x's bounding box from its own point and its children's
	// boxes; a child index of 0 means "no child".
	void update(int x){
		int l=tr[x].l,r=tr[x].r;
		for(int i=0;i<=1;i++){
			tr[x].mn[i]=tr[x].mx[i]=tr[x].d[i];
			if(l!=0){
				tr[x].mn[i]=min(tr[x].mn[i],tr[l].mn[i]);
				tr[x].mx[i]=max(tr[x].mx[i],tr[l].mx[i]);
			}
			if(r!=0){
				tr[x].mn[i]=min(tr[x].mn[i],tr[r].mn[i]);
				tr[x].mx[i]=max(tr[x].mx[i],tr[r].mx[i]);
			}
		}
	}
	// Squared Euclidean distance between two points.
	int dis(node a,node b){
		int res=0;
		for(int i=0;i<=1;i++){
			res+=power(a.d[i]-b.d[i]);
		}
		return res;
	}
	// Upper bound on the squared distance from `ask` to any point inside a's
	// bounding box: on each axis take the farther of the two box faces.
	int get_dis(node a){
		int res=0;
		for(int i=0;i<=1;i++){
			res+=max(power(a.mx[i]-ask.d[i]),power(a.mn[i]-ask.d[i]));
		}
		return res;
	}
	// Build a balanced subtree over tr[l..r] splitting on axis d (alternating
	// per level via d^1). The median tr[mid] becomes the subtree root, its
	// index is stored into rt.
	void build(int &rt,int l,int r,int d){
		int mid=(l+r)>>1;
		now=d;
		nth_element(tr+l,tr+mid,tr+r+1);
		if(l<mid) build(tr[mid].l,l,mid-1,d^1);
		if(r>mid) build(tr[mid].r,mid+1,r,d^1);
		update(mid);
		rt=mid;
	}
	// Visit node x: offer its distance to the heap, then descend into children
	// whose upper bound can still reach the current threshold (-q.top()).
	void query(int x){
		if(!x) return ;
		// Missing children keep bound inf; if visited, query(0) returns at once.
		int sum_l=inf,sum_r=inf,dist=dis(tr[x],ask);
		if(tr[x].l) sum_l=get_dis(tr[tr[x].l]);
		if(tr[x].r) sum_r=get_dis(tr[tr[x].r]);
		// Farther than the current threshold: replace it (values stored negated).
		if(dist>-q.top()){
			q.pop();
			q.push(-dist);
		}
		// Explore the child with the larger bound first so the threshold rises
		// sooner; re-check the threshold before each descent.
		if(sum_l>sum_r){
			if(sum_l>=-q.top()) query(tr[x].l);
			if(sum_r>=-q.top()) query(tr[x].r);
		}else{
			if(sum_r>=-q.top()) query(tr[x].r);
			if(sum_l>=-q.top()) query(tr[x].l);
		}
	}
}KD_Tree;
// Reads n points and k, then queries the tree from every point. The heap thus
// collects the 2k largest ordered-pair distances (each unordered pair is seen
// from both endpoints), and its top gives the k-th farthest pair.
signed main(){
	n=read();
	k=read();
	for(int idx=1;idx<=n;++idx){
		tr[idx].d[0]=read();
		tr[idx].d[1]=read();
		tr[idx].id=idx;
	}
	KD_Tree.build(root,1,n,0);
	// Seed 2k sentinels: stored value inf stands for distance -inf, so every
	// real distance replaces them.
	for(int cnt=0;cnt<2*k;++cnt) q.push(inf);
	for(int idx=1;idx<=n;++idx){
		ask=tr[idx];
		KD_Tree.query(root);
	}
	printf("%lld\n",-q.top());
	return 0;
}

 

posted @ 2019-08-06 17:00  xukl21  阅读(198)  评论(0编辑  收藏  举报