二维K-D Tree:

题目地址:https://www.luogu.org/problemnew/show/P4169

#include<bits/stdc++.h>
#define alph (0.75)
using namespace std;
const int N=1000005;
struct point{
	int x[2];
}p[N];
struct node{
	int mi[2],mx[2],ls,rs,sz;
	point tp;
}tr[N];
int n,m,rt,cur,top,WD,ans,rub[N];
int operator <(point a,point b){
	return a.x[WD]<b.x[WD];
}
int newnode(){//建立新节点
    if(top)return rub[top--];
    else return ++cur;
}
void updata(int k) {
    int l=tr[k].ls,r=tr[k].rs;//结点k左右子结点的指针 
    for(int i=0;i<=1;++i){//二维 
        tr[k].mi[i]=tr[k].mx[i]=tr[k].tp.x[i];//初始化k点所代表矩阵的mimx 
        //如果有左结点 尝试用左结点的最小/大值更新 
		if(l)
			tr[k].mi[i]=min(tr[k].mi[i],tr[l].mi[i]),tr[k].mx[i]=max(tr[k].mx[i],tr[l].mx[i]);
        //如果有右结点 尝试用右结点的最小/大值更新 
		if(r)
			tr[k].mi[i]=min(tr[k].mi[i],tr[r].mi[i]),tr[k].mx[i]=max(tr[k].mx[i],tr[r].mx[i]);
    }
    //更新结点k所代表区域的点数
    tr[k].sz=tr[l].sz+tr[r].sz+1;
}
int build(int l,int r,int wd){//建树 
	//越界判断 
    if(l>r)return 0;
    //初始化新节结点和中点位置 
    int k=newnode(),mid=(l+r)>>1;
    //WD:当前所在维度 nth_element:使位置p+mid是该序列的第x大数
    WD=wd,nth_element(p+l,p+mid,p+r+1),tr[k].tp=p[mid];
    //建左子树和右子树 
    tr[k].ls=build(l,mid-1,wd^1),tr[k].rs=build(mid+1,r,wd^1);
    //更新结点数据 
    updata(k);
	return k;
}
void pia(int k,int num) {//拍扁
	//如果有左结点 拍扁左结点 
    if(tr[k].ls)pia(tr[k].ls,num);
    //?
    p[num+tr[tr[k].ls].sz+1]=tr[k].tp,rub[++top]=k;
    //如果有右结点 拍扁左结点 
    if(tr[k].rs)pia(tr[k].rs,num+tr[tr[k].ls].sz+1);
}
void check(int &k,int wd){//检查子树是否不平衡
    if(alph*tr[k].sz<tr[tr[k].ls].sz||alph*tr[k].sz<tr[tr[k].rs].sz)
        pia(k,0),k=build(1,tr[k].sz,wd);
}
void ins(point tmp,int &k,int wd){//插入
    if(!k) {k=newnode(),tr[k].tp=tmp,tr[k].ls=tr[k].rs=0,updata(k);return;}
    if(tr[k].tp.x[wd]<tmp.x[wd]) ins(tmp,tr[k].rs,wd^1);
    else ins(tmp,tr[k].ls,wd^1);
    updata(k),check(k,wd);//记得在check之前要先pushupdata
}
int getdis(point tmp,int k){//获得当前点到矩形的曼哈顿距离
    int re=0;
    for(int i=0;i<=1;++i)
        re+=max(0,tmp.x[i]-tr[k].mx[i])+max(0,tr[k].mi[i]-tmp.x[i]);
    return re;
}
//两点之间的距离(x1-x2)+(y1-y2) 
int dist(point a,point b){return abs(a.x[0]-b.x[0])+abs(a.x[1]-b.x[1]);}
void query(point tmp,int k) {//查询最近点 k:指针 
	//更新ans取较小值 
    ans=min(ans,dist(tmp,tr[k].tp));
    int dl=INT_MAX,dr=INT_MAX;//初始化l r到该点的距离 
    if(tr[k].ls)dl=getdis(tmp,tr[k].ls);//如果有左结点 dl=点到左子树最近距离 
    if(tr[k].rs)dr=getdis(tmp,tr[k].rs);//如果有右结点 dr=点到右子树最近距离 
	if(dl<dr){//贪心剪枝 
        if(dl<ans)query(tmp,tr[k].ls);//优先在左子树查找 
        if(dr<ans)query(tmp,tr[k].rs);
    }else{
        if(dr<ans)query(tmp,tr[k].rs);
        if(dl<ans)query(tmp,tr[k].ls);
    }
}
int main() {
    int bj;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)scanf("%d%d",&p[i].x[0],&p[i].x[1]);
    rt=build(1,n,0);
    while(m--) {
        point tmp;
        scanf("%d%d%d",&bj,&tmp.x[0],&tmp.x[1]);
        if(bj==1)ins(tmp,rt,0);
        else ans=INT_MAX,query(tmp,rt),printf("%d\n",ans);
    }
    return 0;
}

四维K-D Tree:

题目地址:https://www.luogu.org/problemnew/show/P3769


#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>

using namespace std;

const int N=50009;

int main_d;

inline int minn(int a,int b){if(a<b)return a;return b;}
inline int maxx(int a,int b){if(a>b)return a;return b;}
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0' || '9'<ch){if(ch=='-')f=-1;ch=getchar();}
    while('0'<=ch && ch<='9'){x=x*10+(ch^48);ch=getchar();}
    return x*f;
}
struct point{
	int coord[4],mn[4],id,l,r,mv,v,f;
	int &operator [](int x){
		return coord[x];
	}
	void init(){
		for(int i=0;i<=3;i++)
			mn[i]=coord[i];
		v=mv=f=0;
	}
};
bool operator <(point satori,point koishi){
	return satori[main_d]<koishi[main_d];
}
inline bool ecmp(point a,point b){
	for(int i=0;i<=3;i++)
		if(a[i]!=b[i])
			return 0;
	return 1;
}
inline bool judge(point satori,point koishi){
	for(int i=0;i<=3;i++)
		if(satori[i]>koishi[i])
			return 0;
	return 1;
}
inline bool valid(point satori,point koishi){
	for(int i=1;i<=3;i++)
		if(satori.mn[i]>koishi[i])
			return 0;
	return 1;
}
int p[N],sumt[5];
struct k_dimensional_tree{
	int root,maxval;
	point t[N];
	void updata(int x){
		for(int i=0;i<=3;i++)
			t[x].mn[i]=minn(t[x].mn[i],minn(
			t[t[x].l].mn[i],t[t[x].r].mn[i]));
		t[x].mv=maxx(t[x].v,maxx(t[t[x].l].mv,
		t[t[x].r].mv));
	}
	void push(int x){
		t[x].mv=maxx(t[x].v,maxx(t[t[x].l].mv,
		t[t[x].r].mv));
	}
	inline int check(int x,point p){
		if(!x)return x;
		int ret=1;
		for(int i=0;i<=3;i++)
			if(p[i]<t[x].mn[i])
				ret=0;
		return ret;
	}
	int biu(int l,int r,int d){
		main_d=d;
		int mid=l+r>>1,nxt;
		nth_element(t+l,t+mid,t+r+1);
		nxt=d+1;
		if(nxt==4)
			nxt=1;
		t[mid].init();
		if(l<mid)
			t[mid].l=biu(l,mid-1,nxt),
			t[t[mid].l].f=mid;
		if(mid<r)
			t[mid].r=biu(mid+1,r,nxt),
			t[t[mid].r].f=mid;
		updata(mid);
		return mid;
	}
	void query(int x,point p,int d){
		sumt[0]++;
		if(judge(t[x],p)&&maxval<t[x].v)
			maxval=t[x].v;
		int nxt=d+1;
		if(nxt==4)
			nxt=1;
		if(p[d]>=t[x][d]){
			int a=t[x].l,b=t[x].r;
			if(t[a].mv<t[b].mv)
				swap(a,b);
			if(valid(t[a],p)&&t[a].mv>maxval)
				query(a,p,nxt);
			if(valid(t[b],p)&&t[b].mv>maxval)
				query(b,p,nxt);
		}else{
			if(valid(t[t[x].l],p)&&t[t[x].l].mv>maxval)
				query(t[x].l,p,nxt);
		}
	}
	void modify(int x,int v){
		t[x].v=v;
		push(x);
		while(x=t[x].f)
			push(x);
	}
}koishi;
bool pcmp(int a,int b){
	for(int i=0;i<=3;i++){
		if(koishi.t[a][i]!=koishi.t[b][i]){
			return koishi.t[a][i]<koishi.t[b][i];
		}
	}
	return 0;
}
int main(){
	int n=read();
	for(int i=1;i<=n;i++){
		p[i]=i;
		for(int j=0;j<=3;j++)
			koishi.t[i][j]=read();
		koishi.t[i].id=i;
	}
	koishi.root=koishi.biu(1,n,1);
	sort(p+1,p+n+1,pcmp);
	for(int i=1;i<=n;i++){
		koishi.maxval=0;
		int o=p[i];
		koishi.query(koishi.root,koishi.t[p[i]],1);
		koishi.modify(p[i],koishi.maxval+1);
	}
	printf("%d\n",koishi.t[koishi.root].mv);
	return 0;
}