K-D Tree学习笔记

用途

  • 做各种二维三维四维偏序等等。
  • 代替空间巨大的树套树。
  • 数据较弱的时候水分。

思想

我们发现平衡树这种东西功能强大,然而只能做一维上的询问修改,显得美中不足。

于是我们尝试用平衡树的这种二叉树结构,做更高维的事情。

继续沿用平衡树的左儿子比自己小、右儿子比自己大的形态。这时发现,如果小于号定义得不好,那么做高维询问的时候就很难做。

发明者想到了这样一个方法:我们每过一层就划分一次超矩形。

具体地,给每一层一个\(type\),表示这一层是按哪一维切割。切割某一维的时候,拿出中位数,然后分成两边,于是就完成了小于号的定义。

给一张图:

查询的时候,类似线段树,只要完全包含了就返回,否则无脑往左右儿子跑。

如果要动态加入点怎么办?使用替罪羊树的思想,如果太过不平衡就拍扁重构。

时间复杂度分析?建树显然是\(O(n\log n)\)的,插入据说是\(O(n\log^2 n)\),而询问……引用一下AKMer的证明:

实现

以下代码基本没有参考别人的代码,如果有错误请指正。

为了方便,以下均用二维举例子,拓展到高维不会太难。

(其实也是因为我只写过二维)

定义

定义就是最简单的定义。

struct Point{int x,y,w;}p[sz];
int D;
inline bool cmp(const Point &x,const Point &y){return D?x.y<y.y:x.x<y.x;}
const db alpha=0.75;
int ch[sz<<3][2],fa[sz<<3],size[sz<<3],sum[sz<<3],type[sz<<3]; // 数组大小我也不知道要开几倍,反正O(n)的多开一点应该也没事
Point P[sz<<3];
int L[sz<<3][2],R[sz<<3][2]; // 每一维的上下限
#define ls ch[x][0]
#define rs ch[x][1]
int root,cnt;
int st[sz],top;
void erase(int x){st[++top]=x;ls=rs=size[x]=sum[x]=L[x][0]=R[x][0]=L[x][1]=R[x][1]=0;P[x]=(Point){0,0,0};}
int newnode(){return top?st[top--]:++cnt;} // k-d tree经常重构,所以最好垃圾回收

拍扁重构

利用STL中的nth_element函数,可以做到\(O(n\log n)\)建树。

为了切割得较为均匀,这里选择每一维轮换切割。还有另一种方法是切方差最大的一维,但是懒得写了。

Point pp[sz];int m;
int build(int l,int r,int f)
{
	if (l>r) return 0;
	int x=newnode();fa[x]=f;
	type[x]=type[f]^1;
	int mid=(l+r)>>1;
	D=type[x];nth_element(pp+l,pp+mid,pp+r+1,cmp);
	P[x]=pp[mid];
	L[x][0]=R[x][0]=P[x].x,L[x][1]=R[x][1]=P[x].y;
	ls=build(l,mid-1,x),rs=build(mid+1,r,x);
	size[x]=size[ls]+size[rs]+1;sum[x]=sum[ls]+sum[rs]+P[x].w;
	rep(i,0,1) if (ch[x][i]) rep(k,0,1) chkmin(L[x][k],L[ch[x][i]][k]),chkmax(R[x][k],R[ch[x][i]][k]);
	return x;
}

插入

插入时要判是否平衡,如果不平衡就擦除一整棵子树并重构。

void Erase(int x){if (!x) return;pp[++m]=P[x];Erase(ls),Erase(rs);erase(x);}
void insert(Point p)
{
	int top=-1,x=root;
	if (!x) return (void)(pp[1]=p,root=build(1,1,0));
	while (233)
	{
		if (max(size[ls],size[rs])>size[x]*alpha&&top==-1) top=x;
		++size[x];sum[x]+=p.w;
		chkmin(L[x][0],p.x),chkmax(R[x][0],p.x);
		chkmin(L[x][1],p.y),chkmax(R[x][1],p.y);
		D=type[x];int &y=ch[x][!cmp(p,P[x])];
		if (!y)
		{
			y=newnode();
			L[y][0]=R[y][0]=p.x;
			L[y][1]=R[y][1]=p.y;
			size[y]=1;sum[y]=p.w;type[y]=type[x]^1;fa[y]=x;
			P[y]=p;
			break;
		}
		x=y;
	}
	if (top==-1) return;
	m=0;
	if (top==root) { Erase(top); root=build(1,m,0); return; }
	int f=fa[top],&t=ch[f][ch[f][1]==top];
	Erase(top);
	t=build(1,m,f);
}

询问

无脑做就好了。

int query(int x,int l0,int r0,int l1,int r1)
{
	if (!x) return 0;
	if (l0<=L[x][0]&&R[x][0]<=r0&&l1<=L[x][1]&&R[x][1]<=r1) return sum[x];
	if (r0<L[x][0]||R[x][0]<l0||r1<L[x][1]||R[x][1]<l1) return 0;
	return query(ls,l0,r0,l1,r1)+query(rs,l0,r0,l1,r1)+(l0<=P[x].x&&P[x].x<=r0&&l1<=P[x].y&&P[x].y<=r1?P[x].w:0);
}

例题

LOJ112 三维偏序

LOJ3159「NOI2019」弹跳

目前我也只知道这两个了,毕竟我自己也是刚学qwq

posted @ 2019-07-26 21:15  p_b_p_b  阅读(278)  评论(0编辑  收藏  举报