K-D 树学习笔记
K-D树用于解决2维(或高维)平面上的一切线段树能解决的问题。
K-D树上的每个节点代表一个实际的点,而非一个区间。
K-D树的建树:以二维平面为例,K-D树首先选取当前点集中一个方差最大的维度,并在选出这个维度坐标为中位数的点,用 max_element(a+l,a+mid,a+r+1,cmp)
将它放到mid上,小于它的扔到左边,大于的放到右边,递归去做。
2-D树的查询复杂度是\(O(\sqrt n)\)的,层数为\(\log\)。
支持修改的K-D树需要在 pushup
后判断当前点的左右子树大小是否平衡,适时进行暴力重构。
判断平衡的方法是令 \(α=0.725\),若 \(\max(siz[ls],siz[rs])>siz[x]·α\) 则调用rebuild(x)
。
rebuild
的方法为:首先把x的子树内的所有点取出来,然后 build
。
K-D树常能通过维护剪枝所需要的信息来辅助搜索,骗到可观的分数。
K-D树优化建图等与线段树类似的算法,它也可以完成。
代码模板:(P4148 简单题)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline int read(){
register char ch=getchar();register int x=0,f=1;
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return x*f;
}
void print(int x){
if(x<0){putchar('-'),print(-x);return;}
if(x/10)print(x/10);
putchar(x%10+48);
}
const int N=2e5+5;
int n,las,tot,num,ls[N],rs[N],a[N],siz[N],L[N],R[N],U[N],D[N],type[N],sum[N];
double alpha=0.725;
struct node {
int x,y,v;
}b[N];
void pushup(int x){
siz[x]=siz[ls[x]]+siz[rs[x]]+1,sum[x]=b[x].v,L[x]=R[x]=b[x].x,U[x]=D[x]=b[x].y;
if(ls[x])sum[x]+=sum[ls[x]],L[x]=min(L[x],L[ls[x]]),R[x]=max(R[x],R[ls[x]]),U[x]=max(U[x],U[ls[x]]),D[x]=min(D[x],D[ls[x]]);
if(rs[x])sum[x]+=sum[rs[x]],L[x]=min(L[x],L[rs[x]]),R[x]=max(R[x],R[rs[x]]),U[x]=max(U[x],U[rs[x]]),D[x]=min(D[x],D[rs[x]]);
}bool cmp1(int i, int j) { return b[i].x < b[j].x; }
bool cmp2(int i, int j) { return b[i].y < b[j].y; }
int build(int l,int r){
if(l>r)return 0;
int mid=l+r>>1;
double av1=0,av2=0,va1=0,va2=0;
for(int i=l;i<=r;i++)av1+=b[a[i]].x,av2+=b[a[i]].y;
av1/=r-l+1,av2/=r-l+1;
for(int i=l;i<=r;i++)va1+=(b[a[i]].x-av1)*(b[a[i]].x-av1),va2+=(b[a[i]].y-av2)*(b[a[i]].y-av2);
if(va1>va2)nth_element(a+l,a+mid,a+r+1,cmp1),type[a[mid]]=1;
else nth_element(a+l,a+mid,a+r+1,cmp2),type[a[mid]]=2;
ls[a[mid]]=build(l,mid-1),rs[a[mid]]=build(mid+1,r);
pushup(a[mid]);
return a[mid];
}
void dfs(int x){//改成先序遍历也可以吧?
if(ls[x])dfs(ls[x]);
a[++tot]=x;
if(rs[x])dfs(rs[x]);
}
void rebuild(int &x){
tot=0,dfs(x),x=build(1,tot);
}
int X1,X2,Y1,Y2,ans;
void ask(int x){
if(!x)return;
if(L[x]>X2||R[x]<X1||D[x]>Y2||U[x]<Y1)return;
if(L[x]>=X1&&R[x]<=X2&&D[x]>=Y1&&U[x]<=Y2){
ans+=sum[x];
return;
}
if(b[x].x>=X1&&b[x].x<=X2&&b[x].y>=Y1&&b[x].y<=Y2)ans+=b[x].v;
ask(ls[x]),ask(rs[x]);
}
int vx,vy,vv;
void insert(int &x){
if(!x){
x=++num,b[x].x=vx,b[x].y=vy,b[x].v=vv;
pushup(x);
return;
}
if(type[x]==1){
if(vx<=b[x].x)insert(ls[x]);
else insert(rs[x]);
}
else {
if(vy<=b[x].y)insert(ls[x]);
else insert(rs[x]);
}
pushup(x);
if(siz[x]*alpha<=max(siz[ls[x]],siz[rs[x]]))rebuild(x);
}
int main(){
n=read();
int op,Rt=0;
while(~scanf("%d",&op)){
if(op==3)return 0;
if(op==1){
scanf("%d%d%d",&vx,&vy,&vv);
vx^=las,vy^=las,vv^=las;
insert(Rt);
}
else {
scanf("%d%d%d%d",&X1,&Y1,&X2,&Y2);
X1^=las,Y1^=las,X2^=las,Y2^=las;
ans=0,ask(Rt);
cout<<(las=ans)<<'\n';
}
}
return 0;
}
经典例题:[NOI2019]弹跳
学完K-D树后,K-D树优化建图的思路就是显然的,需要注意的是每个点有K-D树节点一个,和自己一个,两个点分开连边,其中K-D树上的节点沿K-D树边连边,并连向对应的自己点,而每个自己点连向其对应弹跳装置对应的二维区间对应的K-D树上若干节点(这些点中有的是“节点”有的是“自己点”,因为是K-D树不是线段树的缘故,划分出的有整块也有散点),88分到手,MLE。
其实完全没必要提前把图建出来,而是到了一个点的时候,做同样的事情即可。由于使用的是dij,每个点只会取出一次,因此总复杂度是 \(O(m\sqrt n \log n)\) 的。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
inline int read(){
register char ch=getchar();register int x=0,f=1;
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return x*f;
}
void print(int x){
if(x<0){putchar('-'),print(-x);return;}
if(x/10)print(x/10);
putchar(x%10+48);
}
const int N=1.4e5+5;
int n,m,W,H,L[N],R[N],U[N],D[N],ls[N],rs[N],dis[N],dy[N];
priority_queue<pair<int,int> >Q;
struct node {
int x,y,id;
}a[N];
void pushup(int x){
L[x]=R[x]=a[x].x,U[x]=D[x]=a[x].y;
if(ls[x])L[x]=min(L[x],L[ls[x]]),R[x]=max(R[x],R[ls[x]]),U[x]=max(U[x],U[ls[x]]),D[x]=min(D[x],D[ls[x]]);
if(rs[x])L[x]=min(L[x],L[rs[x]]),R[x]=max(R[x],R[rs[x]]),U[x]=max(U[x],U[rs[x]]),D[x]=min(D[x],D[rs[x]]);
}
int build(int l,int r){
if(l>r)return 0;
int mid=l+r>>1;
double av1=0,av2=0,va1=0,va2=0;
for(int i=l;i<=r;i++)av1+=a[i].x,av2+=a[i].y;
av1/=r-l+1,av2/=r-l+1;
for(int i=l;i<=r;i++)va1+=(a[i].x-av1)*(a[i].x-av1),va2+=(a[i].y-av2)*(a[i].y-av2);
if(va1>va2)nth_element(a+l,a+mid,a+r+1,[](node a,node b){return a.x<b.x;});
else nth_element(a+l,a+mid,a+r+1,[](node a,node b){return a.y<b.y;});
ls[mid]=build(l,mid-1),rs[mid]=build(mid+1,r);
pushup(mid);
dy[a[mid].id]=mid;
return mid;
}
struct Edge{
int t_t,ql,qr,qu,qd;
};
vector<Edge>vec[N];
void jian(int tar,int t_t,int ql,int qr,int qu,int qd,int x){
if(!x)return;
if(ql>R[x]||qr<L[x]||qu<D[x]||qd>U[x])return;
if(dis[a[x].id+n]<dis[tar]+t_t)return;
if(ql<=L[x]&&R[x]<=qr&&qd<=D[x]&&U[x]<=qu){
if(dis[a[x].id+n]>dis[tar]+t_t){
dis[a[x].id+n]=dis[tar]+t_t;
Q.push(make_pair(-dis[a[x].id+n],a[x].id+n));
}
return;
}
if(ql<=a[x].x&&a[x].x<=qr&&qd<=a[x].y&&a[x].y<=qu){
if(dis[a[x].id]>dis[tar]+t_t){
dis[a[x].id]=dis[tar]+t_t;
Q.push(make_pair(-dis[a[x].id],a[x].id));
}
}
jian(tar,t_t,ql,qr,qu,qd,ls[x]),jian(tar,t_t,ql,qr,qu,qd,rs[x]);
}
int main(){
n=read(),m=read(),W=read(),H=read();
for(int i=1;i<=n;i++)a[i].x=read(),a[i].y=read(),a[i].id=i;
build(1,n);
for(int i=1,tar,t_t,ql,qr,qd,qu;i<=m;i++){
tar=read(),t_t=read(),ql=read(),qr=read(),qd=read(),qu=read();
vec[tar].emplace_back(Edge{t_t,ql,qr,qu,qd});
}
memset(dis,0x3f,sizeof dis);dis[1]=0;
Q.push(make_pair(0,1));
while(!Q.empty()){
int x=Q.top().second;Q.pop();
if(x<=n){
for(auto o:vec[x])jian(x,o.t_t,o.ql,o.qr,o.qu,o.qd,1+n>>1);
}
else {
if(ls[dy[x-n]]&&dis[a[ls[dy[x-n]]].id+n]>dis[x])dis[a[ls[dy[x-n]]].id+n]=dis[x],Q.push(make_pair(-dis[a[ls[dy[x-n]]].id+n],a[ls[dy[x-n]]].id+n));
if(rs[dy[x-n]]&&dis[a[rs[dy[x-n]]].id+n]>dis[x])dis[a[rs[dy[x-n]]].id+n]=dis[x],Q.push(make_pair(-dis[a[ls[dy[x-n]]].id+n],a[rs[dy[x-n]]].id+n));
if(dis[x-n]>dis[x])dis[x-n]=dis[x],Q.push(make_pair(-dis[x-n],x-n));
}
}
for(int i=2;i<=n;i++)print(dis[i]),putchar('\n');
return 0;
}