KD-Tree 多维树

更新日志

2024/11/12:更新建树操作以及查询操作。

2024/11/13:重构!


思路

KDT,就是多维树,顾名思义,里面的节点都具有多维信息。

说人话,就是每个节点都存了多个变量,并且所有节点拥有的各个变量含义相同,可视作多个“维度”。

KDT的形式类似于二叉搜索树,在某一维度上,左子树内的信息皆小于等于当前节点,右子树同理。

具体维度的划分将会在建树-优化部分详细讲解。

节点信息

与线段树不同,KDT的每一个节点都代表了一组单独的信息。

我们将信息分为两类——维度信息答案信息

前者就是基本的维度信息,建树时就用他们来划分区间,查询时判断的也是他们是否在查询区间内,总的来说,作为条件存在。

而后者就是用来统计答案的,不必多说。

优化

我们考虑像线段树那样,判断查询区间与当前区间的包含关系,进行减枝优化。

具体的,我们额外储存每一个节点所代表区间的信息

对于维度信息,我们通常储存整个区间内这一维度的最大值以及最小值

对于答案信息,我们视需要而定。如果题目要求和,那就储存区间内信息和。题目要求最大值,那就求区间内最大值。以此类推。

具体的优化过程见查询-优化部分。

建树

对于一个区间,我们先选取一个维度作为标准,找出其中位数作为这一棵子树的根节点,再利用快速排序的思想递归建树即可。

可以把所有节点储存在一个数组中,每次返回中位数所在下标(先把中位数放到中位之后的下标!),每次找节点直接在原数组中找即可,无需新建节点数组。

(像插入,也可以准备好节点数组,每插入一个就往后一个空节点填充信息即可。详见插入部分。)

优化

为了保证整棵树尽可能平衡,我们往往采用轮换选维的方式。比如说,这一节点使用 \(k\) 维度,那么其(至多)两个子结点代表区间就使用 \((k+1)\%K\) 的维度。(\(K\) 表示总维数)

可以在递归之后像线段树那样 pushup 一下。

查询

我们从根节点开始,每次判断中位数是否在区间内。这是最朴素的做法。

然额复杂度直接爆表

优化

建树的时候不是储存了区间信息吗,他们的作用这就来了。

首先,我们判断这个区间与查询区间的包含关系。

通常情况下,我们会利用各个维度的最大值与最小值判断。查询区间必然是在每个维度上都有要求的(上界或下界,或者两个都有),看是不是完全被包含、完全没包含即可。

假如完全被包含,直接返回区间答案信息。否则,视要求返回无价值信息。(\(\rm{e.g.}\)求最大值时直接返回\(\rm{-inf}\)

假如不是,我们就递归查询左右两棵子树。但是,当前节点是没有被包含在左右子树之内的,应该额外判定当前节点并更新答案

插入

tips:把一棵树拍扁指的是把这棵树中的所有节点都存到一个临时数组中,并删除树内(包括根节点与其夫节点)的联系。相当于把这棵子树删了。

tips:重构方法就是直接用临时数组里储存的节点编号build一棵新数出来。

额,不优化的做法就是直接把整棵树拍扁重构。复杂度上天。

根号重构

设计一个重构阀值,如果平衡度到达这个阀值就重构整棵树。

不太熟悉,可以借鉴OI-WIKI

二进制分组

思路

2048玩过吧?合成大西瓜玩过吧?对就是这种。

具体地说,我们考虑维护一个所有树的大小均为 \(2\) 的幂次的树林。

维护方法:一旦发现有两棵大小相同的树,就把它们合并(拍扁重构)。

每次插入的时候都新建一棵大小为 \(1\) 的树即可。

但这种方法实现起来很麻烦,比如说我们每次都要去找与它相同大小的树合并……

实现(优化)

事实上,我们发现每种大小的树最多只有一棵,那我们就可以优化一下:

我们先开一个数组\(rt\),其中 \(rt_i\) 表示大小为 \(2^i\) 的树的根节点。如果没有,就设作 \(0\)

同时,我们开一个临时数组,用来储存这一轮需要重构的所有节点。

每次插入时,我们都模拟出现了一个大小为 \(2^0\) 的树,假如已经存在了一棵树,那么它们就直接合并,产生一棵 \(2^1\) 的树……一直判断即可。

这个模拟的过程可以这么做:从小到大遍历 \(rt\),假如当前位置为空,就把最终合并出的树在这里建出来。否则,就合并二者——直接把里面原来存的树拍扁加入临时数组,并接着找下一个位子。(因为合并之后产生了一棵大小为 \(2^{i+1}\))的树。

另外,你可以认为临时数组中存储着新树的所有节点,所以临时数组就代表着未产生的新树。

应该挺详细了。

细节

就讲一个实现找中位数、让其左侧均小、右侧均大的快速方法吧。

nth_element()

对,有这么一个STL函数。

实现方法如下:(就不加注释了,应该看得懂)

nth_element(arr+l,arr+m,arr+r+1,cmp);

然后 \(m\) 这个下标对应元素就成了中位数,整个数组也满足要求了。

模板

放一个模板题,二维空间插入点并找区间点权和。

#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)

下方代码前提☝

const int N=2e5+5,K=2,T=20;

int n;
int ans;

struct node{
    int v[K],mn[K],mx[K];
    int val,sum;
    int lson,rson;
}ns[N];

int ak;
inline bool cmp(int a,int b){return ns[a].v[ak]<ns[b].v[ak];}

struct kdtree{
    int ncnt,acnt;
    int arr[N];
    int rt[T];
    void update(int x){
        ns[x].sum=ns[x].val;
        if(ns[x].lson)ns[x].sum+=ns[ns[x].lson].sum;
        if(ns[x].rson)ns[x].sum+=ns[ns[x].rson].sum;
        for(int i=0;i<K;i++){
            ns[x].mn[i]=ns[x].mx[i]=ns[x].v[i];
            if(ns[x].lson)chmin(ns[x].mn[i],ns[ns[x].lson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].lson].mx[i]);
            if(ns[x].rson)chmin(ns[x].mn[i],ns[ns[x].rson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].rson].mx[i]);
        }
    }
    int build(int l,int r,int k=0){
        int m=l+r>>1;
        ak=k;nth_element(arr+l,arr+m,arr+r+1,cmp);
        if(l<m)ns[arr[m]].lson=build(l,m-1,(k+1)%K);
        if(m<r)ns[arr[m]].rson=build(m+1,r,(k+1)%K);
        update(arr[m]);
        return arr[m];
    }
    void append(int &x){
        arr[++acnt]=x;
        if(ns[x].lson)append(ns[x].lson);
        if(ns[x].rson)append(ns[x].rson);
        x=0;
    }
    void insert(int v[K],int val){
        int x=++ncnt;
        for(int k=0;k<K;k++)ns[x].v[k]=v[k];
        ns[x].val=val;
        acnt=0;
        arr[++acnt]=x;
        for(int i=0;i<T;i++){
            if(rt[i])append(rt[i]);
            else{
                rt[i]=build(1,acnt);
                break;
            }
        }
    }
    inline bool allout(int lb[K],int rb[K],int x){
        for(int k=0;k<K;k++)
            if(ns[x].mn[k]>rb[k]||ns[x].mx[k]<lb[k])return true;
        return false;
    }
    inline bool allin(int lb[K],int rb[K],int x){
        for(int k=0;k<K;k++)
            if(!(ns[x].mn[k]>=lb[k]&&ns[x].mx[k]<=rb[k]))return false;
        return true;
    }
    inline bool check(int lb[K],int rb[K],int x){
        for(int k=0;k<K;k++)
            if(!(ns[x].v[k]>=lb[k]&&ns[x].v[k]<=rb[k]))return false;
        return true;
    }
    int query(int lb[K],int rb[K],int x){
        if(allout(lb,rb,x))return 0;
        if(allin(lb,rb,x))return ns[x].sum;
        int res=0;
        if(check(lb,rb,x))res+=ns[x].val;
        if(ns[x].lson)res+=query(lb,rb,ns[x].lson);
        if(ns[x].rson)res+=query(lb,rb,ns[x].rson);
        return res;
    }
    inline int query(int lb[K],int rb[K]){
        int res=0;
        for(int i=0;i<T;i++)
            if(rt[i])res+=query(lb,rb,rt[i]);
        return res;
    }
}kdt;

例题

LG4148

代码

前注:非题解,不做详细讲解

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef __int128 i128;
typedef double db;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
template <typename Type>
using vec=vector<Type>;
template <typename Type>
using grheap=priority_queue<Type>;
template <typename Type>
using lrheap=priority_queue<Type,vector<Type>,greater<Type> >;
#define fir first
#define sec second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
#define dprint(x) cout<<#x<<"="<<x<<"\n";

const int inf=0x3f3f3f3f;
const int mod=1e9+7/*998244353*/;

const int N=2e5+5,K=2,T=20;

int n;
int ans;

struct node{
    int v[K],mn[K],mx[K];
    int val,sum;
    int lson,rson;
}ns[N];

int ak;
inline bool cmp(int a,int b){return ns[a].v[ak]<ns[b].v[ak];}

struct kdtree{
    int ncnt,acnt;
    int arr[N];
    int rt[T];
    void update(int x){
        ns[x].sum=ns[x].val;
        if(ns[x].lson)ns[x].sum+=ns[ns[x].lson].sum;
        if(ns[x].rson)ns[x].sum+=ns[ns[x].rson].sum;
        for(int i=0;i<K;i++){
            ns[x].mn[i]=ns[x].mx[i]=ns[x].v[i];
            if(ns[x].lson)chmin(ns[x].mn[i],ns[ns[x].lson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].lson].mx[i]);
            if(ns[x].rson)chmin(ns[x].mn[i],ns[ns[x].rson].mn[i]),chmax(ns[x].mx[i],ns[ns[x].rson].mx[i]);
        }
    }
    int build(int l,int r,int k=0){
        int m=l+r>>1;
        ak=k;nth_element(arr+l,arr+m,arr+r+1,cmp);
        if(l<m)ns[arr[m]].lson=build(l,m-1,(k+1)%K);
        if(m<r)ns[arr[m]].rson=build(m+1,r,(k+1)%K);
        update(arr[m]);
        return arr[m];
    }
    void append(int &x){
        arr[++acnt]=x;
        if(ns[x].lson)append(ns[x].lson);
        if(ns[x].rson)append(ns[x].rson);
        x=0;
    }
    void insert(int v[K],int val){
        int x=++ncnt;
        for(int k=0;k<K;k++)ns[x].v[k]=v[k];
        ns[x].val=val;
        acnt=0;
        arr[++acnt]=x;
        for(int i=0;i<T;i++){
            if(rt[i])append(rt[i]);
            else{
                rt[i]=build(1,acnt);
                break;
            }
        }
    }
    void insert(int x){
        acnt=0;
        arr[++acnt]=x;
        for(int i=0;i<T;i++){
            if(rt[i])append(rt[i]);
            else{
                rt[i]=build(1,acnt);
                break;
            }
        }
    }
    inline bool allout(int lb[K],int rb[K],int x){
        for(int k=0;k<K;k++)
            if(ns[x].mn[k]>rb[k]||ns[x].mx[k]<lb[k])return true;
        return false;
    }
    inline bool allin(int lb[K],int rb[K],int x){
        for(int k=0;k<K;k++)
            if(!(ns[x].mn[k]>=lb[k]&&ns[x].mx[k]<=rb[k]))return false;
        return true;
    }
    inline bool check(int lb[K],int rb[K],int x){
        for(int k=0;k<K;k++)
            if(!(ns[x].v[k]>=lb[k]&&ns[x].v[k]<=rb[k]))return false;
        return true;
    }
    int query(int lb[K],int rb[K],int x){
        if(allout(lb,rb,x))return 0;
        if(allin(lb,rb,x))return ns[x].sum;
        int res=0;
        if(check(lb,rb,x))res+=ns[x].val;
        if(ns[x].lson)res+=query(lb,rb,ns[x].lson);
        if(ns[x].rson)res+=query(lb,rb,ns[x].rson);
        return res;
    }
    inline int query(int lb[K],int rb[K]){
        int res=0;
        for(int i=0;i<T;i++)
            if(rt[i])res+=query(lb,rb,rt[i]);
        return res;
    }
}kdt;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n;
    int op;
    int a[K],b[K],c;
    while(true){
        cin>>op;
        if(op==1){
            cin>>a[0]>>a[1]>>c;
            a[0]^=ans;a[1]^=ans;c^=ans;
            kdt.insert(a,c);
        }else if(op==2){
            cin>>a[0]>>a[1]>>b[0]>>b[1];
            a[0]^=ans;a[1]^=ans;b[0]^=ans;b[1]^=ans;
            ans=kdt.query(a,b);
            cout<<ans<<"\n";
        }else break;
    }
    return 0;
}
posted @ 2024-11-12 15:24  HarlemBlog  阅读(12)  评论(0编辑  收藏  举报