k-d tree 学习笔记
以下是一些奇怪的链接有兴趣的可以看看:
https://blog.sengxian.com/algorithms/k-dimensional-tree
http://zgjkt.blog.uoj.ac/blog/1693
https://en.wikipedia.org/wiki/K-d_tree
http://homes.ieu.edu.tr/hakcan/projects/kdtree/kdTree.html
k-d tree就是一个把一个平面(或超平面)划分的东西…
例如一维情况就是在划分一个数轴,这时就成了我们喜闻乐见的二叉搜索树。
二维情况
用横线和竖线,每次把一个矩形瓜分成两半。
还有三维的,四维的,以此类推…… k维的就叫 k-d tree。
这篇文章就讲一下二维的k-d tree。
每一次我们对于k-d tree上的一个点,我们要沿着某一个轴切开,那么这个轴怎么确定呢?一个切x轴的点的孩子就切y轴,切y轴的点的孩子就切x轴,轮流切分,这样做比较优秀,又比较好写。
例如这个点是切x轴的,为了查询比较和谐,那么它的左子树所有点都要x比它小,右子树所有点都要x比它大(当然也可以是等于),这点与二叉搜索树是一致的)。
①建树
比如我现在有n个点,现在要建成一棵美观的k-d tree,怎么建呢?
在bst中我们一般是选比较中间的一个数,然后递归下去建树。
k-d tree中也是类似的,不过我们要确保当前节点的左子树的点当前维度都比它小,右子树的点当前维度都比它大。我们可以每一次在当前层sort一下当前维度,然后选出中位数?
注意到我们只要让左边比较小,右边比较大,中间比较中间就行了,sort一下显然有点浪费,到了下一维度又要重新sort。这里我们使用stl中的nth_element函数,该函数内部用了quick_select算法(就是快速排序的前半部分)来干这件事,复杂度O(n)。
void nth_element(first,nth,last[,cmp]) (比较函数可选)
对[first,last)的元素重新布置,使得nth的元素在sort之后应该在的位置上,而前面的元素都不大于它,后面的元素都不小于它
这样我们就可以建出一棵比较平衡的k-d tree啦。参考线段树建树的复杂度证明很容易发现这玩意儿也是O(nlogn)的。
②插入
额似乎这玩意儿没必要讲,也就是和当前点的当前维度比较一下然后插入就行。
对于一棵比较平衡的k-d tree,插入显然是O(logn)的。
③询问
首先我们要解决的是…k-d tree有卵用?
显然k-d tree可以被当作是一个比较厉害的k维(本文中是二维)线段树!在这上面可以打tag,可以搞询问(划分矩形的时候注意一下),甚至还可以搞segment tree beats(参见吉司机的集训队论文)
但是二维情况下k-d tree当作线段树用处不大…因为前有二维树状数组/二维线段树,后有cdq分治、整体二分啥的…除非出题人丧心病狂又卡空间又要在线
正常的k-d tree是用来询问最近点的利器!
比如有一道这样的题(bzoj2648),要求每次询问离一个给定点最近(曼哈顿距离)的黑点和加入一个黑点。
真是良心出题人!
查找时我们可以这么做,先把ans设成inf,然后我们从根结点开始递归查询。
对于k-d tree上的一个节点,先用这个点更新答案。然后对于它的两个孩子,优先选择离当前点近的一个矩形(注意这个值显然不一定能取到)递归更新答案,然后更新完答案如果到另一个儿子的那个矩形距离比答案小就也用另一个儿子递归更新答案。
对于一棵比较平衡的k-d tree,这样查询的复杂度最坏是O(sqrt(n))的,一般是O(log(n))的。
(这个“一般”的性质就跟spfa所谓的O(ke)差不多)
可以注意到这里的“近”“距离”对于欧几里得距离与曼哈顿距离都适用。
至于k近点对也比较简单,我们维护一个大根堆,每次比较堆顶啥的就行。
显然这样的复杂度比刚才要多一个klogk。
需要注意的是,你可能需要把k*=2(因为一个点对会算两次)
④一些小问题
有没有注意到刚才的复杂度上有若干个类似“比较平衡”的字眼?
是的,k-d tree的性质就与一棵没有旋转的二叉查找树差不多。
如果建完树不再插入,这时k-d tree是十分平衡的,可以保证查询的一切复杂度。
如果建完树还插入或者干脆不建树直接插入,这时k-d tree对于随机数据仍然是“十分平衡”的,而随便构造一个数据(例如1,1 2,2 3,3...这样的数据)就可以成功地卡飞k-d tree。
所以如果又有插入又有询问,你只能假装数据是随机的...
代码(常数非常大,有一些题目可能无法通过...原因玄学)
bzoj2716(权限,时限80s才勉强艹过去)
#include <iostream> #include <stdio.h> #include <stdlib.h> #include <algorithm> using namespace std; #define gc getchar() int g_i() { int tmp=0; bool fu=0; char s; while(s=gc,s!='-'&&(s<'0'||s>'9')) ; if(s=='-') fu=1; else tmp=s-'0'; while(s=gc,s>='0'&&s<='9') tmp=tmp*10+s-'0'; if(fu) return -tmp; else return tmp; } #define gi g_i() #define pob #define pc(x) (putchar(x)) struct foce {~foce() {pob; fflush(stdout);}} _foce; inline void pstr(char* p) {while(*p)pc(*(p++));} namespace ib {char b[100];} inline void pi(int x) { if(x==0) {pc(48); return;} if(x<0) {pc('-'); x=-x;} char *s=ib::b; while(x) *(++s)=x%10, x/=10; while(s!=ib::b) pc((*(s--))+48); } struct pnt { int _x,_y; pnt() {} pnt(int x,int y) {_x=x; _y=y;} }; int Abs(int x) {return (x>=0)?x:-x;} #define SZ 1234567 int dis(pnt a,pnt b) {return Abs(a._x-b._x)+Abs(a._y-b._y);} //kdtree bool dim_; bool cmp(pnt a,pnt b) { if(!dim_) return a._x<b._x||(a._x==b._x&&a._y<b._y); else return a._y<b._y||(a._y==b._y&&a._x<b._x); } int rot,M=0,ch[SZ][2],inf=1000000000; pnt pp[SZ],lu[SZ],rd[SZ]; inline int min_(int a,int b) { return (a<=b)?a:b; } inline int min_(int a,int b,int c) { int ans=a; if(b<ans) ans=b; if(c<ans) ans=c; return ans; } inline int max_(int a,int b,int c) { int ans=a; if(b>ans) ans=b; if(c>ans) ans=c; return ans; } void upd(int x) { lu[x]._x=min_(lu[ch[x][0]]._x,lu[ch[x][1]]._x,lu[x]._x); lu[x]._y=min_(lu[ch[x][0]]._y,lu[ch[x][1]]._y,lu[x]._y); rd[x]._x=max_(rd[ch[x][0]]._x,rd[ch[x][1]]._x,rd[x]._x); rd[x]._y=max_(rd[ch[x][0]]._y,rd[ch[x][1]]._y,rd[x]._y); } int dis(int x,pnt p) { int ans=0; if(p._x<lu[x]._x) ans+=lu[x]._x-p._x; else if(p._x>rd[x]._x) ans+=p._x-rd[x]._x; if(p._y<lu[x]._y) ans+=lu[x]._y-p._y; else ans+=p._y-rd[x]._y; return ans; } void init() { rot=M=0; lu[0]._x=inf; lu[0]._y=inf; rd[0]._x=-inf; rd[0]._y=-inf; } pnt ps[SZ]; int build(int l,int r,bool d=0) { if(l>=r) return 0; int c=++M,m=l+r>>1; dim_=d; nth_element(ps+l,ps+m,ps+r,cmp); pp[c]=lu[c]=rd[c]=ps[m]; ch[c][0]=build(l,m,!d); ch[c][1]=build(m+1,r,!d); upd(c); return c; } int ans; void query(int x,pnt p) { if(!x) return; ans=min_(ans,dis(p,pp[x])); int dc[2]={dis(ch[x][0],p),dis(ch[x][1],p)}; int d=dc[0]>dc[1]; query(ch[x][d],p); if(dc[!d]<ans) query(ch[x][!d],p); } void ins(int& x,pnt p) { if(!x) {x=++M; pp[x]=lu[x]=rd[x]=p; return;} int d=!cmp(p,pp[x]); dim_=!dim_; ins(ch[x][d],p); upd(x); } int n,m; int main() { init(); n=gi, m=gi; for(int i=0;i<n;i++) ps[i]._x=gi, ps[i]._y=gi; rot=build(0,n); while(m--) { int t=gi, x=gi, y=gi; if(t==1) {dim_=0; ins(rot,pnt(x,y));} else if(t==2) { ans=inf; query(rot,pnt(x,y)); pi(ans); pc(10); } } pob; }
bzoj4520
#include <iostream> #include <stdio.h> #include <stdlib.h> #include <algorithm> #include <queue> using namespace std; typedef long long ll; #define gc getchar() int g_i() { int tmp=0; bool fu=0; char s; while(s=gc,s!='-'&&(s<'0'||s>'9')) ; if(s=='-') fu=1; else tmp=s-'0'; while(s=gc,s>='0'&&s<='9') tmp=tmp*10+s-'0'; if(fu) return -tmp; else return tmp; } #define gi g_i() #define pob #define pc(x) (putchar(x)) struct foce {~foce() {pob; fflush(stdout);}} _foce; inline void pstr(char* p) {while(*p)pc(*(p++));} namespace ib {char b[100];} inline void pi(int x) { if(x==0) {pc(48); return;} if(x<0) {pc('-'); x=-x;} char *s=ib::b; while(x) *(++s)=x%10, x/=10; while(s!=ib::b) pc((*(s--))+48); } struct pnt { int _x,_y; pnt() {} pnt(int x,int y) {_x=x; _y=y;} }; int Abs(int x) {return (x>=0)?x:-x;} ll pf(ll x) {return x*x;} #define SZ 1234567 ll dis(pnt a,pnt b) {return pf(a._x-b._x)+pf(a._y-b._y);} //kdtree bool dim_; bool cmp(pnt a,pnt b) { if(!dim_) return a._x<b._x||(a._x==b._x&&a._y<b._y); else return a._y<b._y||(a._y==b._y&&a._x<b._x); } int rot,M=0,ch[SZ][2]; ll inf=2147483647; ll inf_ll=100000000000000000LL; pnt pp[SZ],lu[SZ],rd[SZ]; inline int min_(int a,int b) { return (a<=b)?a:b; } inline int min_(int a,int b,int c) { int ans=a; if(b<ans) ans=b; if(c<ans) ans=c; return ans; } inline int max_(int a,int b,int c) { int ans=a; if(b>ans) ans=b; if(c>ans) ans=c; return ans; } void upd(int x) { lu[x]._x=min_(lu[ch[x][0]]._x,lu[ch[x][1]]._x,lu[x]._x); lu[x]._y=min_(lu[ch[x][0]]._y,lu[ch[x][1]]._y,lu[x]._y); rd[x]._x=max_(rd[ch[x][0]]._x,rd[ch[x][1]]._x,rd[x]._x); rd[x]._y=max_(rd[ch[x][0]]._y,rd[ch[x][1]]._y,rd[x]._y); } ll dis(int x,pnt p) { ll ans=0; ans=max(ans,dis(rd[x],p)); ans=max(ans,dis(lu[x],p)); ans=max(ans,dis(pnt(rd[x]._x,lu[x]._y),p)); ans=max(ans,dis(pnt(lu[x]._x,rd[x]._y),p)); return ans; } void init() { rot=M=0; lu[0]._x=inf; lu[0]._y=inf; rd[0]._x=-inf; rd[0]._y=-inf; } pnt ps[SZ]; int build(int l,int r,bool d=0) { if(l>r) return 0; int c=++M,m=l+r>>1; dim_=d; nth_element(ps+l,ps+m,ps+1+r,cmp); pp[c]=lu[c]=rd[c]=ps[m]; ch[c][0]=build(l,m-1,!d); ch[c][1]=build(m+1,r,!d); upd(c); return c; } typedef priority_queue<ll,vector<ll>,greater<ll> >sheap; sheap pq; void query(int x,pnt p) { if(!x) return; ll d1=dis(p,pp[x]); if(d1>=pq.top()) pq.pop(), pq.push(d1); ll dc[2]={dis(ch[x][0],p),dis(ch[x][1],p)}; int d=dc[0]<dc[1]; query(ch[x][d],p); if(dc[!d]>=pq.top()) query(ch[x][!d],p); } int n,m; int main() { init(); n=gi, m=gi; for(int i=1;i<=n;i++) ps[i]._x=gi, ps[i]._y=gi; rot=build(1,n); for(int i=1;i<=m*2;i++) pq.push(-inf_ll); for(int i=1;i<=n;i++) query(rot,ps[i]); printf("%lld\n",pq.top()); }
P.S. 常数比较优秀的代码在 https://blog.sengxian.com/algorithms/k-dimensional-tree 这里有...指针版