主席树讲解

概念

可以看作是可持久化权值线段树
可持久化又可以分为部分可持久化和完全可持久化
部分可持久化————所有版本都可以访问,但是只有最新版本可以修改
完全可持久化————所有版本都既可以访问又可以修改

引入

以一道经典例题来讲述主席树————给定 \(n\) 个整数构成的序列 ,将对于指定的闭区间 \([l,r]\) 查询其区间内的第 \(k\) 小值。
原题链接————洛谷P3834

思路解析

利用主席树解决问题的主要思想是:每次插入数据时都要保存历史版本,以便于查询区间第 \(k\) 小值
假定当前数组里的数是25957 6405 15770 26287 26465
接下来进行两次查询,第一次是求区间 \([2,2]\) 中第1小的数,第二次是求区间 \([3,4]\) 中第1小的数
找第k小的数最重要的就是找到该数在原数组中的位置,我们可以利用前缀和区间减法的思路巧妙解决问题
即如果需要得到 \([l,r]\) 的统计信息,只需要用 \([1,r]\) 的信息减去 \([1,l-1]\) 的信息就行了。
考虑这样保存信息

先把原数组排序并去重,原序列就变成了 6405 15770 25957 26287 26465
这时,每一个数就对应了一个有序的位置
利用动态开点建树
注意主席树不能使用堆式存储法,就是说不能用 \(x\times2\)\(x\times2+1\) 来表示左右儿子,而是应该动态开点,并保存每个节点的左右儿子编号,对该顺序序列进行划分
img
build返回当前结点的编号,递归到最上层即为根结点编号
初始状态就是
img
以下结点和区间一一对应
1————[1,5]
2————[1,3]
7————[4,5]
3————[1,2]
6————[3,3]
8————[4,4]
9————[5,5]
4————[1,1]
5————[2,2]

既然要保存历史版本,那么每次输入时我们也要进行动态开点保存信息
当第一次输入25957时,我们可以二分找到它在顺序序列的位置,可以知道它是整个序列第3大的数
用sum[i]表示第i个结点的区间内有多少个归位的数
在更新结点时,本来应该是sum[1]++,sum[2]++,sum[6]++,因为他们都包含了3这个位置
由于要保存上一个版本,所以开出新的点10,11,12对应sum[10],sum[11],sum[12]
更新信息的代码是img
在更新操作的代码中,pre表示的是上一个输入操作后的线段树根节点信息,利用上一个信息更新当前输入操作后线段树的结点信息
如果当前k这个位置不大于mid,那么对应的结点肯定会在左边,因此更新左边的信息,否则更新右边的信息
在第一次操作后会变成
img
这个时候sum[10] = sum[11] = sum[12] = 1,而sum[1] = sum[2] = sum[6] = 0;
这样我们就实现了保存两个版本的信息

如果这个时候我们要查询区间[1,1]第1小的数
利用前缀和的思路,我们先找第0个版本和第一个版本的根结点,分别是1和10
sum[11] - sum[2] = 1 而1 <= 1
说明在区间[1,3]中有一个处于对应位置的数,继续向下递归则可以找到这个数
代码
img
在第二次插入时,我们输入6405,根据上述更新的思路,第二个版本的线段树变成了img
此时sum[13] = sum[15] = sum[16] = 1(因为6405是顺序序列位置为1的数)
但是注意此时 \(sum[14] = 2 = sum[15](最新版本) + sum[12](上个版本)\)
此时查询[1,2]中第2小的数
sum[14] - sum[2] = 2 <= 2
向左递归 sum[15] - sum[3] = 1 < 2 向右递归
此时在12号结点,对应区间是[3,3]对应的数是顺序序列的3号位——即25957
因此,该问题可以得到解决

空间复杂度

在主席树的创建中,我们不断使用动态开点来保留版本和更新版本。
而一棵线段树只会出现 \(2\times n-1\) 个结点,一共n次修改,每次修改最多增加 \(\lceil \log_2n\rceil + 1\) 个结点(即推到叶子结点)
因此最坏情况下会达到 \(2\times n-1 + n\times(\lceil \log_2n\rceil + 1)\)
在例题中 \(n = 10^{5}\),大概为 \(20\times10^{5}\)

完整代码

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
const int N = 2e5 + 5;
int rs[N<<5];
int ls[N<<5];
int sum[N<<5];
int s[N];
int cint[N];
int tot;
int rt[N];
int build(int l,int r){
    int root = ++tot;
    if(l == r) return root;
    int mid = l + (r - l >> 1);
    ls[root] = build(l,mid);
    rs[root] = build(mid+1,r);
    return root;
}
int update(int pre,int l,int r,int k){
    int root = ++tot;
    sum[root] = sum[pre];
    ls[root] = ls[pre],rs[root] = rs[pre];
    if(l == r){
        sum[root] += 1;
        return root;
    }
    int mid = l + (r - l >> 1);
    if(k <= mid) ls[root] = update(ls[pre],l,mid,k);
    else rs[root] = update(rs[pre],mid+1,r,k);
    sum[root] = sum[ls[root]] + sum[rs[root]];
    return root;
}
int query(int u,int v,int k,int l,int r){
    int mid = l + (r - l >> 1);
    int x = sum[ls[v]] - sum[ls[u]];
    if(l == r) return l;
    if(k <= x) return query(ls[u],ls[v],k,l,mid);
    else return query(rs[u],rs[v],k-x,mid+1,r);

}
inline int read(){
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')
            f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*f;
    }
inline void print(int x){
    if(x<0){
        putchar('-');
        x=-x;
    }
    if(x>9)
        print(x/10);
        putchar(x%10+'0');
    }
signed main(){
    int n,m;
    n = read();
    m = read();
    for(int i = 1; i <= n; i++) s[i] = read(),cint[i] = s[i];
    sort(s+1,s+1+n);
    int len = unique(s+1,s+1+n) - (s+1);
    rt[0] = build(1,len);

    for(int i = 1; i <= n; i++){
        int p = lower_bound(s+1,s+len+1,cint[i]) - s;
        rt[i] = update(rt[i-1],1,len,p);
        
    }
    for(int i = 1; i <= m; i++){
        int l,r,k;
        l = read(),r = read(),k = read();
        int res = query(rt[l-1],rt[r],k,1,len);
        print(s[res]),printf("\n");
    }
    return 0;
}

posted @ 2022-08-31 22:37  Sun-Wind  阅读(82)  评论(0编辑  收藏  举报