【学习小记】KD-Tree

Preface

听说KD树实在是个大神器
可以解决多维空间多维偏序点权和,可以求某个点的空间最近最远点

就二维平面上的来说,复杂度在\(O(n\log n)\)\(O(n\sqrt n)\)不等

嫌KD树不平衡了还可以来一个替罪羊树式的暴力重构
再也不用担心写不出树套树了!(狗头)

Text

这是个什么东西呢?

在一维数轴上,可以简单的比较两个值的大小,我们就很容易将它们分治,建出一个二叉搜索树
但是拓展到高维往往就很困难

KD树解决了这一麻烦,KD树也是一个二叉分治结构,不同的是KD树的分治维是会变的。

比如说对于1号节点,它是按照X坐标分治的,即它的左子树的X坐标比他小,Y坐标比他大
对于另一个2号节点,他可能又按照Y坐标分治了。

可以根据以下构树的过程来加深理解。

构造

以二维平面为例,我们现在有一堆点。

按什么坐标分治其实是可以随机的,而另一种比较好的方法是每一维轮流,父亲按X分治,当前就按Y分治这样子下去

实现上来说,我们当前的节点可以表示成节点序列上的一段区间[l,r]。
现在我们想要知道节点序列上区间[l,r]中第D维的中位数,然后把比它小的放到它左边,大的放到右边。

这个东西排序似乎要\(n\log\)
然而c++STL库提供了一个std::nth_element()的函数,找出序列第k大,并且把小的放左边,大的放右边,内部的相对顺序不保证,并且它的复杂度是线性的。(原理好像是基于分治的)
它一般可以传四个参,依次是序列头,序列第k个位置,序列维,比较函数
调用完以后序列第k个位置就已经放好了。

这样就搞定了,递归建树即可,注意如果是左边或右边没有元素了就不递归那边。

显然,平衡树能记录的值它都能记录。子树和,子树最大值等等

为了方便,我们一般还记录子树中每一维坐标的范围。

可以看出,每个子树就是一个矩形内的所有点。

动态插入

类似splay一样向下搜索到叶子,直接加入即可。

然而多次以后可能会出现不平衡的情况。(即某一侧子树过大)
那么我们可以类似替罪羊树的做法,插入以后找最高的一个不平衡点,对它的子树进行重构。

当然,由于KD树的本质是一个暴力,你完全可以每插入7000次或者10000次之类的把整棵树重构一遍(好写简洁,慢不了多少)

动态删除

类似替罪羊树,删除一个节点时,找它左子树中这一维最大的那个节点替代它。
不平衡就重构。

基本操作/询问

矩形加/矩形赋值/矩形查询,类似线段树那样做就好了。
对于K维操作,n个点,总的复杂度大概是\(O(kn^{1-{1\over k}})\)(当然1维不是)

有一个它能做而树套树完成不了的问题——K维最近/最远点
这个距离可以是曼哈顿距离,欧几里得距离,甚至切比雪夫距离。

它的本质就是一个搜索。
由于我们知道了每个子树的每维坐标的范围,我们就可以算出询问节点到这个子树的矩形的距离至少(至多)是多少

以欧几里得最近点为例,
令矩形为\((x1,y1,x2,y2)\),询问点为\((p,q)\),那么这个点到这个矩形内的点距离至少是\(\sqrt{max(0,max(x1-p,p-x2))^2+max(0,max(y1-q,q-y2))^2}\)

简单来说,这就是一个估价函数。
如果估价函数值比当前已经搜出来的答案还大,那么直接退出。

随机数据下复杂度是\(O(n\log n)\)的,构造下大概是\(O(n\sqrt n)\)
k维复杂度大概是\(O(kn^{1-{1\over k}})\)

这可以用来优化某些DP

先说这么多,如果有别的应用以后再补。

Code

一个相当丑的模板

namespace KDT
{
	int D,rt,t[N][2],sz[N],dw[N],tr[N],le;
	int a[N][3],rg[N][3][2],ans;
	bool cmp(int x,int y)
	{
		return a[x][D]<a[y][D];
	}
	void upd(const int &k,const int &y)
	{
		if(y) 
		{
			fo(j,0,2) 
			{
				rg[k][j][0]=min(rg[k][j][0],rg[y][j][0]);
				rg[k][j][1]=max(rg[k][j][1],rg[y][j][1]);
			}
			sz[k]=sz[k]+sz[y];
		}
	}
	void up(int k)
	{
		sz[k]=1;fo(j,0,2) rg[k][j][0]=rg[k][j][1]=a[k][j];
		if(t[k][0]) upd(k,t[k][0]);
		if(t[k][1]) upd(k,t[k][1]);
	}
	int make(int l,int r)
	{
		int mid=(l+r)>>1;
		D=(D+1)%3;
		std::nth_element(tr+l,tr+mid,tr+r+1,cmp);
		int k=tr[mid];dw[k]=D;
		t[k][0]=(l<mid)?make(l,mid-1):0;
		t[k][1]=(r>mid)?make(mid+1,r):0;
		up(k);
		return k;
	}
	void ins(int &k,int p)
	{	
		if(!k) 
		{
			k=p;sz[p]=1;dw[p]=(D+1)%3;
			fo(j,0,2) rg[p][j][0]=rg[p][j][1]=a[p][j];
			return;
		}
		D=dw[k];
		if(a[p][D]<=a[k][D]) ins(t[k][0],p);
		else ins(t[k][1],p);
		up(k);
	}
	int pri(int k,int x,int y,int z)
	{
		return rg[k][0][0]+sqr(max(0,max(rg[k][1][0]-y,y-rg[k][1][1])))+sqr(max(0,max(rg[k][2][0]-z,z-rg[k][2][1])));
	}
	int calc(int k,int x,int y,int z)
	{
		return a[k][0]+sqr(a[k][1]-y)+sqr(a[k][2]-z);
	}
	void query(int k,int x,int y,int z)
	{
		if(pri(k,x,y,z)>=ans) return;
		ans=min(ans,calc(k,x,y,z));
		if(t[k][0]) query(t[k][0],x,y,z);
		if(t[k][1]) query(t[k][1],x,y,z);
	}
}
posted @ 2019-03-21 22:02  BAJim_H  阅读(334)  评论(0编辑  收藏  举报