K-D 树学习笔记

K-D树用于解决2维(或高维)平面上的一切线段树能解决的问题。

K-D树上的每个节点代表一个实际的点,而非一个区间。

K-D树的建树:以二维平面为例,K-D树首先选取当前点集中一个方差最大的维度,并在选出这个维度坐标为中位数的点,用 max_element(a+l,a+mid,a+r+1,cmp) 将它放到mid上,小于它的扔到左边,大于的放到右边,递归去做。

2-D树的查询复杂度是\(O(\sqrt n)\)的,层数为\(\log\)

支持修改的K-D树需要在 pushup 后判断当前点的左右子树大小是否平衡,适时进行暴力重构。

判断平衡的方法是令 \(α=0.725\),若 \(\max(siz[ls],siz[rs])>siz[x]·α\) 则调用rebuild(x)

rebuild 的方法为:首先把x的子树内的所有点取出来,然后 build

K-D树常能通过维护剪枝所需要的信息来辅助搜索,骗到可观的分数。

K-D树优化建图等与线段树类似的算法,它也可以完成。

代码模板:(P4148 简单题)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline int read(){
	register char ch=getchar();register int x=0,f=1;
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return x*f;
}
void print(int x){
	if(x<0){putchar('-'),print(-x);return;}
	if(x/10)print(x/10);
	putchar(x%10+48);
}
const int N=2e5+5;
int n,las,tot,num,ls[N],rs[N],a[N],siz[N],L[N],R[N],U[N],D[N],type[N],sum[N];
double alpha=0.725;
struct node {
	int x,y,v;
}b[N];
void pushup(int x){
	siz[x]=siz[ls[x]]+siz[rs[x]]+1,sum[x]=b[x].v,L[x]=R[x]=b[x].x,U[x]=D[x]=b[x].y;
	if(ls[x])sum[x]+=sum[ls[x]],L[x]=min(L[x],L[ls[x]]),R[x]=max(R[x],R[ls[x]]),U[x]=max(U[x],U[ls[x]]),D[x]=min(D[x],D[ls[x]]);
	if(rs[x])sum[x]+=sum[rs[x]],L[x]=min(L[x],L[rs[x]]),R[x]=max(R[x],R[rs[x]]),U[x]=max(U[x],U[rs[x]]),D[x]=min(D[x],D[rs[x]]);
}bool cmp1(int i, int j) { return b[i].x < b[j].x; }

bool cmp2(int i, int j) { return b[i].y < b[j].y; }
int build(int l,int r){
	if(l>r)return 0;
	int mid=l+r>>1;
	double av1=0,av2=0,va1=0,va2=0;
	for(int i=l;i<=r;i++)av1+=b[a[i]].x,av2+=b[a[i]].y;
	av1/=r-l+1,av2/=r-l+1;
	for(int i=l;i<=r;i++)va1+=(b[a[i]].x-av1)*(b[a[i]].x-av1),va2+=(b[a[i]].y-av2)*(b[a[i]].y-av2);
	if(va1>va2)nth_element(a+l,a+mid,a+r+1,cmp1),type[a[mid]]=1;
	else nth_element(a+l,a+mid,a+r+1,cmp2),type[a[mid]]=2;
	ls[a[mid]]=build(l,mid-1),rs[a[mid]]=build(mid+1,r); 
	pushup(a[mid]);
	return a[mid];
}
void dfs(int x){//改成先序遍历也可以吧? 
	if(ls[x])dfs(ls[x]);
	a[++tot]=x;
	if(rs[x])dfs(rs[x]); 
}
void rebuild(int &x){
	tot=0,dfs(x),x=build(1,tot);
}
int X1,X2,Y1,Y2,ans;
void ask(int x){
	if(!x)return;
	if(L[x]>X2||R[x]<X1||D[x]>Y2||U[x]<Y1)return;
	if(L[x]>=X1&&R[x]<=X2&&D[x]>=Y1&&U[x]<=Y2){
		ans+=sum[x];
		return;
	} 
	if(b[x].x>=X1&&b[x].x<=X2&&b[x].y>=Y1&&b[x].y<=Y2)ans+=b[x].v;
	ask(ls[x]),ask(rs[x]);
}
int vx,vy,vv;
void insert(int &x){
	if(!x){
		x=++num,b[x].x=vx,b[x].y=vy,b[x].v=vv;
		pushup(x);
		return;
	}
	if(type[x]==1){
		if(vx<=b[x].x)insert(ls[x]);
		else insert(rs[x]);
	}
	else {
		if(vy<=b[x].y)insert(ls[x]);
		else insert(rs[x]);
	}
	pushup(x);
	if(siz[x]*alpha<=max(siz[ls[x]],siz[rs[x]]))rebuild(x);
}
int main(){
	n=read();
	int op,Rt=0;
	while(~scanf("%d",&op)){
		if(op==3)return 0;
		if(op==1){
			scanf("%d%d%d",&vx,&vy,&vv);
			vx^=las,vy^=las,vv^=las;
			insert(Rt);
		}
		else {
			scanf("%d%d%d%d",&X1,&Y1,&X2,&Y2);
			X1^=las,Y1^=las,X2^=las,Y2^=las;
			ans=0,ask(Rt);
			cout<<(las=ans)<<'\n';
		}
	}
	return 0;
}

经典例题:[NOI2019]弹跳

学完K-D树后,K-D树优化建图的思路就是显然的,需要注意的是每个点有K-D树节点一个,和自己一个,两个点分开连边,其中K-D树上的节点沿K-D树边连边,并连向对应的自己点,而每个自己点连向其对应弹跳装置对应的二维区间对应的K-D树上若干节点(这些点中有的是“节点”有的是“自己点”,因为是K-D树不是线段树的缘故,划分出的有整块也有散点),88分到手,MLE。

其实完全没必要提前把图建出来,而是到了一个点的时候,做同样的事情即可。由于使用的是dij,每个点只会取出一次,因此总复杂度是 \(O(m\sqrt n \log n)\) 的。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline int read(){
	register char ch=getchar();register int x=0,f=1;
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return x*f;
}
void print(int x){
	if(x<0){putchar('-'),print(-x);return;}
	if(x/10)print(x/10);
	putchar(x%10+48);
}
const int N=1.4e5+5;
int n,m,W,H,L[N],R[N],U[N],D[N],ls[N],rs[N],dis[N],dy[N];
priority_queue<pair<int,int> >Q;
struct node {
	int x,y,id;
}a[N];
void pushup(int x){
	L[x]=R[x]=a[x].x,U[x]=D[x]=a[x].y;
	if(ls[x])L[x]=min(L[x],L[ls[x]]),R[x]=max(R[x],R[ls[x]]),U[x]=max(U[x],U[ls[x]]),D[x]=min(D[x],D[ls[x]]);
	if(rs[x])L[x]=min(L[x],L[rs[x]]),R[x]=max(R[x],R[rs[x]]),U[x]=max(U[x],U[rs[x]]),D[x]=min(D[x],D[rs[x]]);
}
int build(int l,int r){
	if(l>r)return 0;
	int mid=l+r>>1;
	double av1=0,av2=0,va1=0,va2=0;
	for(int i=l;i<=r;i++)av1+=a[i].x,av2+=a[i].y;
	av1/=r-l+1,av2/=r-l+1;
	for(int i=l;i<=r;i++)va1+=(a[i].x-av1)*(a[i].x-av1),va2+=(a[i].y-av2)*(a[i].y-av2);
	if(va1>va2)nth_element(a+l,a+mid,a+r+1,[](node a,node b){return a.x<b.x;});
	else nth_element(a+l,a+mid,a+r+1,[](node a,node b){return a.y<b.y;});
	ls[mid]=build(l,mid-1),rs[mid]=build(mid+1,r);
	pushup(mid);
	dy[a[mid].id]=mid;
	return mid;
}
struct Edge{
	int t_t,ql,qr,qu,qd;
};
vector<Edge>vec[N];
void jian(int tar,int t_t,int ql,int qr,int qu,int qd,int x){
	if(!x)return;
	if(ql>R[x]||qr<L[x]||qu<D[x]||qd>U[x])return;
	if(dis[a[x].id+n]<dis[tar]+t_t)return;
	if(ql<=L[x]&&R[x]<=qr&&qd<=D[x]&&U[x]<=qu){
		if(dis[a[x].id+n]>dis[tar]+t_t){
			dis[a[x].id+n]=dis[tar]+t_t;
			Q.push(make_pair(-dis[a[x].id+n],a[x].id+n));
		}
		return;
	}
	if(ql<=a[x].x&&a[x].x<=qr&&qd<=a[x].y&&a[x].y<=qu){
		if(dis[a[x].id]>dis[tar]+t_t){
			dis[a[x].id]=dis[tar]+t_t;
			Q.push(make_pair(-dis[a[x].id],a[x].id));
		}
	}
	jian(tar,t_t,ql,qr,qu,qd,ls[x]),jian(tar,t_t,ql,qr,qu,qd,rs[x]);
}
int main(){
	n=read(),m=read(),W=read(),H=read();
	for(int i=1;i<=n;i++)a[i].x=read(),a[i].y=read(),a[i].id=i;
	build(1,n);
	for(int i=1,tar,t_t,ql,qr,qd,qu;i<=m;i++){
		tar=read(),t_t=read(),ql=read(),qr=read(),qd=read(),qu=read();
		vec[tar].emplace_back(Edge{t_t,ql,qr,qu,qd});
	}
	memset(dis,0x3f,sizeof dis);dis[1]=0;
	Q.push(make_pair(0,1));
	while(!Q.empty()){
		int x=Q.top().second;Q.pop();
		if(x<=n){
			for(auto o:vec[x])jian(x,o.t_t,o.ql,o.qr,o.qu,o.qd,1+n>>1);
		}
		else {
			if(ls[dy[x-n]]&&dis[a[ls[dy[x-n]]].id+n]>dis[x])dis[a[ls[dy[x-n]]].id+n]=dis[x],Q.push(make_pair(-dis[a[ls[dy[x-n]]].id+n],a[ls[dy[x-n]]].id+n));
			if(rs[dy[x-n]]&&dis[a[rs[dy[x-n]]].id+n]>dis[x])dis[a[rs[dy[x-n]]].id+n]=dis[x],Q.push(make_pair(-dis[a[ls[dy[x-n]]].id+n],a[rs[dy[x-n]]].id+n));
			if(dis[x-n]>dis[x])dis[x-n]=dis[x],Q.push(make_pair(-dis[x-n],x-n));
		}
	}
	for(int i=2;i<=n;i++)print(dis[i]),putchar('\n');
	return 0;
}
posted @ 2023-02-24 14:03  pengyule  阅读(31)  评论(0编辑  收藏  举报