[模板] K-D Tree
K-D Tree
K-D Tree可以看作二叉搜索树的高维推广, 它的第 \(k\) 层以所有点的第 \(k\) 维作为关键字对点做出划分.
为了保证划分均匀, 可以以第 \(k\) 维排名在中间的节点为划分节点. 这可以利用 std::nth_element
实现.
K-D Tree 支持单点修改. 为了保证 K-D Tree 的平衡性, 可以利用替罪羊树的思想, 在某个子树不平衡时重构这个子树.
同时, 对于每个点可以代表它所有子节点的 \([min(x_i), max(x_i)]\) 的一块超空间. 因此可以实现区间查询的操作.
根据 Wikipedia 的说法, 区间查询的最坏复杂度为单次 \(O(k \cdot n^{1-\frac 1k})\). (不会证)
其他操作
//to update
Code
//kdt
const int dk=2;
const db alp=0.75;
struct tp{
int v[dk];
int& operator[](int p){return v[p];}
const int& operator[](int p)const{return v[p];}
};
typedef const tp& ctp;
int key;
bool cmp1(ctp a,ctp b){return a[key]<b[key];}
bool eq(ctp a,ctp b){
rep(i,0,dk-1)if(a[i]!=b[i])return 0;
return 1;
}
struct tnd{tp po,mi,mx;int v,sum,sz,ch[2];}tree[nsz];
#define ls(p) tree[p].ch[0]
#define rs(p) tree[p].ch[1]
int rt=0,pt=0;
bool isbad(int p){return tree[ls(p)].sz>tree[p].sz*alp||tree[rs(p)].sz>tree[p].sz*alp;}
void pu(int p){
int l=ls(p),r=rs(p);
tree[p].sum=tree[l].sum+tree[r].sum+tree[p].v;
tree[p].sz=tree[l].sz+tree[r].sz+1;
rep(i,0,dk-1){
tree[p].mi[i]=min(tree[p].po[i],min((l?tree[l].mi[i]:(int)1e9),(r?tree[r].mi[i]:(int)1e9)));
tree[p].mx[i]=max(tree[p].po[i],max((l?tree[l].mx[i]:-1),(r?tree[r].mx[i]:-1)));
}
}
int li[nsz],pl=0;
bool cmp2(int a,int b){return cmp1(tree[a].po,tree[b].po);}
void pia(int rt){
if(ls(rt))pia(ls(rt));
li[++pl]=rt;
if(rs(rt))pia(rs(rt));
}
void build(int &rt,int rl,int rr,int k){
if(rl>rr){rt=0;return;}
int mid=(rl+rr)>>1;
key=k,nth_element(li+rl,li+mid,li+rr+1,cmp2);
rt=li[mid];
build(ls(rt),rl,mid-1,(k+1)%dk);
build(rs(rt),mid+1,rr,(k+1)%dk);
pu(rt);
}
void rebuild(int &rt,int k){
pl=0,pia(rt);
build(rt,1,pl,k);
}
void insert(tp p,int v,int &rt,int k){
if(rt==0){rt=++pt,tree[rt].po=p,tree[rt].v=v,pu(rt);return;}
if(eq(tree[rt].po,p)){tree[rt].v+=v,tree[rt].sum+=v;return;}
if(p[k]<=tree[rt].po[k])insert(p,v,ls(rt),(k+1)%dk);
else insert(p,v,rs(rt),(k+1)%dk);
pu(rt);
if(isbad(rt))rebuild(rt,k);
}
bool in(tp a1,tp a2,tp b1,tp b2){//(a1,a2) in (b1,b2)
rep(i,0,dk-1){
if(a1[i]<b1[i]||a2[i]>b2[i])return 0;
}
return 1;
}
bool out(tp a1,tp a2,tp b1,tp b2){//(a1,a2) completely out of (b1,b2)
rep(i,0,dk-1){
if(a2[i]<b1[i]||a1[i]>b2[i])return 1;
}
return 0;
}
int qu(tp a1,tp a2,int rt){
if(rt==0||out(a1,a2,tree[rt].mi,tree[rt].mx))return 0;
if(in(tree[rt].mi,tree[rt].mx,a1,a2))return tree[rt].sum;
int res=0;
if(in(tree[rt].po,tree[rt].po,a1,a2))res+=tree[rt].v;
res+=qu(a1,a2,ls(rt))+qu(a1,a2,rs(rt));
return res;
}