KD-Tree 多维树
更新日志
2024/11/12:更新建树操作以及查询操作。
2024/11/13:重构!
思路
KDT,就是多维树,顾名思义,里面的节点都具有多维信息。
说人话,就是每个节点都存了多个变量,并且所有节点拥有的各个变量含义相同,可视作多个“维度”。
KDT的形式类似于二叉搜索树,在某一维度上,左子树内的信息皆小于等于当前节点,右子树同理。
具体维度的划分将会在建树-优化部分详细讲解。
节点信息
与线段树不同,KDT的每一个节点都代表了一组单独的信息。
我们将信息分为两类——维度信息与答案信息。
前者就是基本的维度信息,建树时就用他们来划分区间,查询时判断的也是他们是否在查询区间内,总的来说,作为条件存在。
而后者就是用来统计答案的,不必多说。
优化
我们考虑像线段树那样,判断查询区间与当前区间的包含关系,进行减枝优化。
具体的,我们额外储存每一个节点所代表区间的信息。
对于维度信息,我们通常储存整个区间内这一维度的最大值以及最小值。
对于答案信息,我们视需要而定。如果题目要求和,那就储存区间内信息和。题目要求最大值,那就求区间内最大值。以此类推。
具体的优化过程见查询-优化部分。
建树
对于一个区间,我们先选取一个维度作为标准,找出其中位数作为这一棵子树的根节点,再利用快速排序的思想递归建树即可。
可以把所有节点储存在一个数组中,每次返回中位数所在下标(先把中位数放到中位之后的下标!),每次找节点直接在原数组中找即可,无需新建节点数组。
(像插入,也可以准备好节点数组,每插入一个就往后一个空节点填充信息即可。详见插入部分。)
优化
为了保证整棵树尽可能平衡,我们往往采用轮换选维的方式。比如说,这一节点使用 \(k\) 维度,那么其(至多)两个子结点代表区间就使用 \((k+1)\%K\) 的维度。(\(K\) 表示总维数)
可以在递归之后像线段树那样 pushup
一下。
查询
我们从根节点开始,每次判断中位数是否在区间内。这是最朴素的做法。
然额复杂度直接爆表
优化
建树的时候不是储存了区间信息吗,他们的作用这就来了。
首先,我们判断这个区间与查询区间的包含关系。
通常情况下,我们会利用各个维度的最大值与最小值判断。查询区间必然是在每个维度上都有要求的(上界或下界,或者两个都有),看是不是完全被包含、完全没包含即可。
假如完全被包含,直接返回区间答案信息。否则,视要求返回无价值信息。(\(\rm{e.g.}\)求最大值时直接返回\(\rm{-inf}\))
假如不是,我们就递归查询左右两棵子树。但是,当前节点是没有被包含在左右子树之内的,应该额外判定当前节点并更新答案。
插入
tips:把一棵树拍扁指的是把这棵树中的所有节点都存到一个临时数组中,并删除树内(包括根节点与其夫节点)的联系。相当于把这棵子树删了。
tips:重构方法就是直接用临时数组里储存的节点编号build
一棵新数出来。
额,不优化的做法就是直接把整棵树拍扁重构。复杂度上天。
根号重构
设计一个重构阀值,如果平衡度到达这个阀值就重构整棵树。
不太熟悉,可以借鉴OI-WIKI。
二进制分组
思路
2048玩过吧?合成大西瓜玩过吧?对就是这种。
具体地说,我们考虑维护一个所有树的大小均为 \(2\) 的幂次的树林。
维护方法:一旦发现有两棵大小相同的树,就把它们合并(拍扁重构)。
每次插入的时候都新建一棵大小为 \(1\) 的树即可。
但这种方法实现起来很麻烦,比如说我们每次都要去找与它相同大小的树合并……
实现(优化)
事实上,我们发现每种大小的树最多只有一棵,那我们就可以优化一下:
我们先开一个数组\(rt\),其中 \(rt_i\) 表示大小为 \(2^i\) 的树的根节点。如果没有,就设作 \(0\)。
同时,我们开一个临时数组,用来储存这一轮需要重构的所有节点。
每次插入时,我们都模拟出现了一个大小为 \(2^0\) 的树,假如已经存在了一棵树,那么它们就直接合并,产生一棵 \(2^1\) 的树……一直判断即可。
这个模拟的过程可以这么做:从小到大遍历 \(rt\),假如当前位置为空,就把最终合并出的树在这里建出来。否则,就合并二者——直接把里面原来存的树拍扁加入临时数组,并接着找下一个位子。(因为合并之后产生了一棵大小为 \(2^{i+1}\))的树。
另外,你可以认为临时数组中存储着新树的所有节点,所以临时数组就代表着未产生的新树。
应该挺详细了。
细节
就讲一个实现找中位数、让其左侧均小、右侧均大的快速方法吧。
nth_element()
对,有这么一个STL
函数。
实现方法如下:(就不加注释了,应该看得懂)
nth_element(arr+l,arr+m,arr+r+1,cmp);
然后 \(m\) 这个下标对应元素就成了中位数,整个数组也满足要求了。
模板
放一个模板题,二维空间插入点并找区间点权和。
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
下方代码前提☝
const int N=2e5+5,K=2,T=20;
int n;
int ans;
struct node{
int v[K],mn[K],mx[K];
int val,sum;
int lson,rson;
}ns[N];
int ak;
inline bool cmp(int a,int b){return ns[a].v[ak]<ns[b].v[ak];}
struct kdtree{
int ncnt,acnt;
int arr[N];
int rt[T];
void update(int x){
ns[x].sum=ns[x].val;
if(ns[x].lson)ns[x].sum+=ns[ns[x].lson].sum;
if(ns[x].rson)ns[x].sum+=ns[ns[x].rson].sum;
for(int i=0;i<K;i++){
ns[x].mn[i]=ns[x].mx[i]=ns[x].v[i];
if(ns[x].lson)chmin(ns[x].mn[i],ns[ns[x].lson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].lson].mx[i]);
if(ns[x].rson)chmin(ns[x].mn[i],ns[ns[x].rson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].rson].mx[i]);
}
}
int build(int l,int r,int k=0){
int m=l+r>>1;
ak=k;nth_element(arr+l,arr+m,arr+r+1,cmp);
if(l<m)ns[arr[m]].lson=build(l,m-1,(k+1)%K);
if(m<r)ns[arr[m]].rson=build(m+1,r,(k+1)%K);
update(arr[m]);
return arr[m];
}
void append(int &x){
arr[++acnt]=x;
if(ns[x].lson)append(ns[x].lson);
if(ns[x].rson)append(ns[x].rson);
x=0;
}
void insert(int v[K],int val){
int x=++ncnt;
for(int k=0;k<K;k++)ns[x].v[k]=v[k];
ns[x].val=val;
acnt=0;
arr[++acnt]=x;
for(int i=0;i<T;i++){
if(rt[i])append(rt[i]);
else{
rt[i]=build(1,acnt);
break;
}
}
}
inline bool allout(int lb[K],int rb[K],int x){
for(int k=0;k<K;k++)
if(ns[x].mn[k]>rb[k]||ns[x].mx[k]<lb[k])return true;
return false;
}
inline bool allin(int lb[K],int rb[K],int x){
for(int k=0;k<K;k++)
if(!(ns[x].mn[k]>=lb[k]&&ns[x].mx[k]<=rb[k]))return false;
return true;
}
inline bool check(int lb[K],int rb[K],int x){
for(int k=0;k<K;k++)
if(!(ns[x].v[k]>=lb[k]&&ns[x].v[k]<=rb[k]))return false;
return true;
}
int query(int lb[K],int rb[K],int x){
if(allout(lb,rb,x))return 0;
if(allin(lb,rb,x))return ns[x].sum;
int res=0;
if(check(lb,rb,x))res+=ns[x].val;
if(ns[x].lson)res+=query(lb,rb,ns[x].lson);
if(ns[x].rson)res+=query(lb,rb,ns[x].rson);
return res;
}
inline int query(int lb[K],int rb[K]){
int res=0;
for(int i=0;i<T;i++)
if(rt[i])res+=query(lb,rb,rt[i]);
return res;
}
}kdt;
例题
代码
前注:非题解,不做详细讲解
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 i128;
typedef double db;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
template <typename Type>
using vec=vector<Type>;
template <typename Type>
using grheap=priority_queue<Type>;
template <typename Type>
using lrheap=priority_queue<Type,vector<Type>,greater<Type> >;
#define fir first
#define sec second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
#define dprint(x) cout<<#x<<"="<<x<<"\n";
const int inf=0x3f3f3f3f;
const int mod=1e9+7/*998244353*/;
const int N=2e5+5,K=2,T=20;
int n;
int ans;
struct node{
int v[K],mn[K],mx[K];
int val,sum;
int lson,rson;
}ns[N];
int ak;
inline bool cmp(int a,int b){return ns[a].v[ak]<ns[b].v[ak];}
struct kdtree{
int ncnt,acnt;
int arr[N];
int rt[T];
void update(int x){
ns[x].sum=ns[x].val;
if(ns[x].lson)ns[x].sum+=ns[ns[x].lson].sum;
if(ns[x].rson)ns[x].sum+=ns[ns[x].rson].sum;
for(int i=0;i<K;i++){
ns[x].mn[i]=ns[x].mx[i]=ns[x].v[i];
if(ns[x].lson)chmin(ns[x].mn[i],ns[ns[x].lson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].lson].mx[i]);
if(ns[x].rson)chmin(ns[x].mn[i],ns[ns[x].rson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].rson].mx[i]);
}
}
int build(int l,int r,int k=0){
int m=l+r>>1;
ak=k;nth_element(arr+l,arr+m,arr+r+1,cmp);
if(l<m)ns[arr[m]].lson=build(l,m-1,(k+1)%K);
if(m<r)ns[arr[m]].rson=build(m+1,r,(k+1)%K);
update(arr[m]);
return arr[m];
}
void append(int &x){
arr[++acnt]=x;
if(ns[x].lson)append(ns[x].lson);
if(ns[x].rson)append(ns[x].rson);
x=0;
}
void insert(int v[K],int val){
int x=++ncnt;
for(int k=0;k<K;k++)ns[x].v[k]=v[k];
ns[x].val=val;
acnt=0;
arr[++acnt]=x;
for(int i=0;i<T;i++){
if(rt[i])append(rt[i]);
else{
rt[i]=build(1,acnt);
break;
}
}
}
void insert(int x){
acnt=0;
arr[++acnt]=x;
for(int i=0;i<T;i++){
if(rt[i])append(rt[i]);
else{
rt[i]=build(1,acnt);
break;
}
}
}
inline bool allout(int lb[K],int rb[K],int x){
for(int k=0;k<K;k++)
if(ns[x].mn[k]>rb[k]||ns[x].mx[k]<lb[k])return true;
return false;
}
inline bool allin(int lb[K],int rb[K],int x){
for(int k=0;k<K;k++)
if(!(ns[x].mn[k]>=lb[k]&&ns[x].mx[k]<=rb[k]))return false;
return true;
}
inline bool check(int lb[K],int rb[K],int x){
for(int k=0;k<K;k++)
if(!(ns[x].v[k]>=lb[k]&&ns[x].v[k]<=rb[k]))return false;
return true;
}
int query(int lb[K],int rb[K],int x){
if(allout(lb,rb,x))return 0;
if(allin(lb,rb,x))return ns[x].sum;
int res=0;
if(check(lb,rb,x))res+=ns[x].val;
if(ns[x].lson)res+=query(lb,rb,ns[x].lson);
if(ns[x].rson)res+=query(lb,rb,ns[x].rson);
return res;
}
inline int query(int lb[K],int rb[K]){
int res=0;
for(int i=0;i<T;i++)
if(rt[i])res+=query(lb,rb,rt[i]);
return res;
}
}kdt;
int main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
cin>>n;
int op;
int a[K],b[K],c;
while(true){
cin>>op;
if(op==1){
cin>>a[0]>>a[1]>>c;
a[0]^=ans;a[1]^=ans;c^=ans;
kdt.insert(a,c);
}else if(op==2){
cin>>a[0]>>a[1]>>b[0]>>b[1];
a[0]^=ans;a[1]^=ans;b[0]^=ans;b[1]^=ans;
ans=kdt.query(a,b);
cout<<ans<<"\n";
}else break;
}
return 0;
}