主席树学习笔记

【前言】

主席树,又称可持久化线段树,是最重要的可持久化数据结构之一。

和大多数可持久化数据结构类似,主席树保存了线段树的所有历史形态,从而可以维护许多特殊的东东。

个人认为其实比线段树还要简单啦。

【前置芝士】

  1. 线段树。(这个是必备的吧)
  2. 权值线段树。(其实没什么区别)
  3. 动态开点线段树。(其实可以不了解)
  4. 离散化等基本数据结构题操作。

显然这些都很简单啊。

(如果之前没有了解过可持久化数据结构,建议先学可持久化 0/1 trie 树,它的思想和主席树是一样一样的)

【主席树】

【引入】

板子题

静态查询区间 \(k\) 大值。

假设没有区间操作,只需要维护全局 \(k\) 大值,这显然是权值线段树的基本操作。

现在有区间操作,我们可以利用前缀和的思想,开 \(n\) 棵权值线段树分别维护位置 \(1\sim i\)

然后查询的时候边减边判断就好了,可惜这样你的空间吃不消。

所以就有主席树的诞生。

【主要思想】

考虑节省空间,我们发现每次插入一个数时,发生改变的至多只有 \(\log n\) 个节点,

其它节点相当于是在上一次的基础上直接复制下来的,也就没必要保留了。

只需要从当前节点指向历史上某棵线段树的节点即可,这样每次新增节点至多只有 \(\log n\) 个。

保证了空间复杂度。

关键主要思想它就这么多。

【代码实现】

以模板题为例,其实简单到极致。

因为本质上还是 \(n\) 棵线段树,只是共用了大量节点,所以还是要维护 \(n\) 个根。

先不管那些,初始化时先建棵树:

int root[N];
int ls[N * 40], rs[N * 40], sum[N * 40];

int build(int l, int r){
	int p = ++tot;
	sum[p] = 0;
	if(l == r) return p;
	int mid = (l + r) >> 1;
	ls[p] = build(l, mid);
	rs[p] = build(mid+1, r);
	return p;
}

root[0] = build(1, t); // t 是值域最大值。

然后考虑在上一棵树的基础上增添节点:

int insert(int last, int l, int r, int k, int v){
	int p = ++tot;
	sum[p] = sum[last]; ls[p] = ls[last]; rs[p] = rs[last];
    //直接把上一次的节点引用过来。
	if(l == r){
		sum[p] += v;//单店修改。
		return p;
	}
	int mid = (l + r) >> 1;
	if(k <= mid) ls[p] = insert(ls[last], l, mid, k, v);
	else rs[p] = insert(rs[last], mid+1, r, k, v);
	sum[p] = sum[ls[p]] + sum[rs[p]];
	return p;
}

for(int i=1; i<=n; i++)
	root[i] = insert(root[i-1], 1, t, x, 1);
	// x 为当前插入的数值。

比较核心的部分是查询,其实每道题的查询都略有不同,但是核心是前缀和思想,将 sum[r] - sum[l - 1]

int ask(int p, int q, int l, int r, int k){
	if(l == r) return l;
	int mid = (l + r) >> 1;
	int lcnt = sum[ls[p]] - sum[ls[q]];
	if(k <= lcnt) return ask(ls[p], ls[q], l, mid, k);
	else return ask(rs[p], rs[q], mid+1, r, k - lcnt);
}

int l = read(), r = read(), k = read();
int ans = ask(root[r], root[l-1], 1, t, k);

扔一下总的代码吧。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 200010;

int n, m, tot;
int a[N], b[N], root[N];
int ls[N * 40], rs[N * 40], sum[N * 40];

int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0' || c>'9') f=(c=='-')?-1:1,c=getchar();
	while(c>='0' && c<='9') x=x*10+c-48,c=getchar();
	return x*f;
}

int build(int l, int r){
	int p = ++tot;
	sum[p] = 0;
	if(l == r) return p;
	int mid = (l + r) >> 1;
	ls[p] = build(l, mid);
	rs[p] = build(mid+1, r);
	return p;
}

int insert(int last, int l, int r, int k, int v){
	int p = ++tot;
	sum[p] = sum[last]; ls[p] = ls[last]; rs[p] = rs[last];
	if(l == r){
		sum[p] += v;
		return p;
	}
	int mid = (l + r) >> 1;
	if(k <= mid) ls[p] = insert(ls[last], l, mid, k, v);
	else rs[p] = insert(rs[last], mid+1, r, k, v);
	sum[p] = sum[ls[p]] + sum[rs[p]];
	return p;
}

int ask(int p, int q, int l, int r, int k){
	if(l == r) return l;
	int mid = (l + r) >> 1;
	int lcnt = sum[ls[p]] - sum[ls[q]];
	if(k <= lcnt) return ask(ls[p], ls[q], l, mid, k);
	else return ask(rs[p], rs[q], mid+1, r, k - lcnt);
}

int main(){
	n = read(), m = read();
	for(int i=1; i<=n; i++) a[i] = b[i] = read();
	sort(b+1, b+n+1);
	int t = unique(b+1, b+n+1) - (b+1);
	root[0] = build(1, t);
	for(int i=1; i<=n; i++){
		int x = lower_bound(b+1, b+t+1, a[i]) - b;
		root[i] = insert(root[i-1], 1, t, x, 1);
	}
	for(int i=1; i<=m; i++){
		int l = read(), r = read(), k = read();
		int ans = ask(root[r], root[l-1], 1, t, k);
		printf("%d\n", b[ans]);
	}
	return 0;
}

【时空复杂度】

假设查询次数与序列总长同级。

显然时间复杂度为 \(O(n\log n)\),和单棵权值线段树一样。

空间复杂度为预处理建树和新增节点总和,为 \(O(n+n\log n)\)

【简单例题】

【例题一】

粟粟的书架

强行将题目分为两问

  1. 查询 \(n\times m\) 的矩阵中某个子矩阵。(\(n,m\leq 200\),询问数量 \(M\leq 2\times 10^5\)
  2. 查询序列 \(n\) 中某个区间。(\(n\leq 5\times 10^5\),询问数量 \(M\leq 2\times 10^4\)

问题都是:从所查询区域中选出尽量少的数和 \(\geq k\),输出所选数的数量,或者判断无解。

其中每个数 \(a_i\leq 1000\)\(k\leq 10^9\)

难度在于问题转化上。

首先考虑问题,发现肯定是贪心地从大数往小数取,可惜如果这样两问的时间复杂度都会太高。

再考虑二分答案,每次二分所选择的数字中最小的那个数,问题转化为判定是否可行。

显然这个问题具有单调性,而且最小的数字越大,数的总数也就越少,满足题意。

两问都用到这个思想,但具体处理方法不太相同:

  1. 直接上二维前缀和,需要同时维护数的总和数的数量
  2. 区间问题用主席树,同时也要维护数的总和数的数量

第一问预处理时间复杂度为 \(O(nm\times 1000)\),查询时间复杂度为 \(O(M\log 1000)\)

第二问预处理时间复杂度为 \(O(n\log 1000)\),查询时间复杂度为 \(O(M\log 1000)\)

问题统计时有些技巧,参考代码见下。

const int N = 500010;
const int M = N * 40;

int n, m, q;
int a[210][210];
int val[210][210][1010], cnt[210][210][1010];

int Sum(int x1, int y1, int x2, int y2, int k){
    return val[x2][y2][k] - val[x1-1][y2][k] - val[x2][y1-1][k] + val[x1-1][y1-1][k];
}

int Cnt(int x1, int y1, int x2, int y2, int k){
    return cnt[x2][y2][k] - cnt[x1-1][y2][k] - cnt[x2][y1-1][k] + cnt[x1-1][y1-1][k];
}

void work1(){
    for(int i=1; i<=n; i++)
        for(int j=1; j<=m; j++) a[i][j] = read();
    for(int k=1; k<=1000; k++)
        for(int i=1; i<=n; i++)
            for(int j=1; j<=m; j++){
                val[i][j][k] = val[i-1][j][k] + val[i][j-1][k] - val[i-1][j-1][k] + (a[i][j] >= k ? a[i][j] : 0);
                cnt[i][j][k] = cnt[i-1][j][k] + cnt[i][j-1][k] - cnt[i-1][j-1][k] + (a[i][j] >= k ? 1 : 0);
            }
    while(q--){
        int x1 = read(), y1 = read(), x2 = read(), y2 = read(), c = read();
        if(Sum(x1, y1, x2, y2, 1) < c) {puts("Poor QLW"); continue;}
        int l = 1, r = 1000;
        while(l < r){
            int mid = (l + r + 1) >> 1;
            if(Sum(x1, y1, x2, y2, mid) >= c) l = mid;
            else r = mid - 1;
        }
        int ans = Cnt(x1, y1, x2, y2, l) - (Sum(x1, y1, x2, y2, l) - c) / l;
        printf("%d\n", ans);
    }
}

int tot, T = 1000;
int p[N], root[N];
struct Tree{int ls, rs, cnt, sum;} t[M];

int build(int l, int r){
    int p = ++ tot;
    t[p].sum = t[p].cnt = 0;
    if(l == r) return p;
    int mid = (l + r) >> 1;
    t[p].ls = build(l, mid);
    t[p].rs = build(mid+1, r);
    return p;
}

int insert(int last, int l, int r, int k, int v){
    int p = ++ tot;
    t[p] = t[last];
    if(l == r){
        t[p].sum += v;
        t[p].cnt ++;
        return p;
    }
    int mid = (l + r) >> 1;
    if(k <= mid) t[p].ls = insert(t[last].ls, l, mid, k, v);
    else t[p].rs = insert(t[last].rs, mid+1, r, k, v);
    t[p].sum = t[t[p].ls].sum + t[t[p].rs].sum;
    t[p].cnt = t[t[p].ls].cnt + t[t[p].rs].cnt;
    return p;
}

int ask(int p, int q, int l, int r, int k){
    if(l == r) return (k - 1) / l + 1;
    int mid = (l + r) >> 1;
    int rsum = t[t[p].rs].sum - t[t[q].rs].sum;
    int rcnt = t[t[p].rs].cnt - t[t[q].rs].cnt;
    if(k <= rsum)
        return ask(t[p].rs, t[q].rs, mid+1, r, k);
    else
        return ask(t[p].ls, t[q].ls, l, mid, k - rsum) + rcnt;
}

void work2(){
    for(int i=1; i<=m; i++) p[i] = read();
    root[0] = build(1, T);
    for(int i=1; i<=m; i++)
        root[i] = insert(root[i-1], 1, T, p[i], p[i]);
    while(q--){
        int l, r, c;
        read(), l = read(), read(), r = read(), c = read();
        if(t[root[r]].sum - t[root[l-1]].sum < c) {puts("Poor QLW"); continue;}
        printf("%d\n", ask(root[r], root[l-1], 1, T, c));
    }
}

int main(){
    n = read(), m = read(), q = read();
    if(n > 1) work1();
    else work2();
    return 0;
}

【例题二】

森林

维护一个森林,支持连边查询两点路径上点权 \(k\) 大值

挺考验码力的一道题。

先不考虑连边操作,其实就变成了这道题

套路和大部分的树上题目差不多,主席树维护节点到根那条路径的点权,那么:

\[ans = sum(u)+sum(v)-sum(lca)-sum(fa_{lca}) \]

简单画个图就明白了。

关键是它还有连边操作,考虑每次将一个节点当做另一个节点的子节点,然后自根到底暴力重新维护一遍。

可惜这样时间复杂度最坏为 \(O(n^2\log n)\)

如果采用启发式合并呢?这样时间复杂度变为了很好的 \(O(n\log^2 n)\),只需要在根节点记录总子树大小。

这道题就基本上结束了,维护连通块用并查集,因为有连边操作,求 \(\rm lca\) 用倍增。

参考代码:

const int N = 80010;

int n, m, q, t;
int a[N], b[N];
int fa[N], sz[N], dep[N];
int st[N][30];
bool vis[N];
int head[N], cnt;
struct Edge{int nxt, to;} ed[N * 4];

int tot, root[N];
struct Tree{int ls, rs, sum;} tr[N * 600];

void add(int u, int v){
    ed[++cnt] = (Edge){head[u], v};
    head[u] = cnt;
}

int build(int l, int r){
    int p = ++tot;
    tr[p].sum = 0;
    if(l == r) return p;
    int mid = (l + r) >> 1;
    tr[p].ls = build(l, mid);
    tr[p].rs = build(mid+1, r);
    return p;
}

int insert(int last, int l, int r, int k){
    int p = ++tot;
    tr[p] = tr[last];
    if(l == r){
        tr[p].sum ++;
        return p;
    }
    int mid = (l + r) >> 1;
    if(k <= mid) tr[p].ls = insert(tr[last].ls, l, mid, k);
    else tr[p].rs = insert(tr[last].rs, mid+1, r, k);
    tr[p].sum = tr[tr[p].ls].sum + tr[tr[p].rs].sum;
    return p; 
}

int ask(int p, int q, int lastp, int lastq, int l, int r, int k){
    if(l == r) return l;
    int mid = (l + r) >> 1;
    int lcnt = tr[tr[p].ls].sum + tr[tr[q].ls].sum - tr[tr[lastp].ls].sum - tr[tr[lastq].ls].sum;
    if(k <= lcnt)
        return ask(tr[p].ls, tr[q].ls, tr[lastp].ls, tr[lastq].ls, l, mid, k);
    else 
        return ask(tr[p].rs, tr[q].rs, tr[lastp].rs, tr[lastq].rs, mid+1, r, k - lcnt);
}

void dfs(int u, int Fa, int Top){
    vis[u] = true;
    dep[u] = dep[Fa] + 1;
    st[u][0] = Fa;
    for(int i=1; i<=20; i++) st[u][i] = st[st[u][i-1]][i-1];
    sz[Top] ++;
    fa[u] = Fa;
    root[u] = insert(root[Fa], 1, t, a[u]);
    for(int i=head[u]; i; i=ed[i].nxt)
        if(ed[i].to != Fa) dfs(ed[i].to, u, Top);
}

int Get(int x){
    if(x == fa[x]) return x;
    return fa[x] = Get(fa[x]);
}

int Get_Lca(int x, int y){
    if(dep[x] < dep[y]) swap(x, y);
    for(int i=20; i>=0; i--)
        if(dep[st[x][i]] >= dep[y]) x = st[x][i];
    if(x == y) return x;
    for(int i=20; i>=0; i--)
        if(st[x][i] != st[y][i]) 
            x = st[x][i], y = st[y][i];
    return st[x][0];
}

int main(){
    read();
    n = read(), m = read(), q = read();
    for(int i=1; i<=n; i++) a[i] = b[i] = read();
    sort(b+1, b+n+1);
    t = unique(b+1, b+n+1) - (b+1);
    for(int i=1; i<=n; i++)
        a[i] = lower_bound(b+1, b+t+1, a[i]) - b;
    for(int i=1; i<=m; i++){
        int u = read(), v = read();
        add(u, v); add(v, u);
    }
    root[0] = build(1, t);
    for(int i=1; i<=n; i++)
        if(!vis[i]) dfs(i, 0, i), fa[i] = i;
    int ans = 0;
    char str[5];
    while(q--){
        scanf("%s", str);
        if(str[0] == 'Q'){
            int x = read(), y = read(), k = read();
            x ^= ans, y ^= ans, k ^= ans;
            int Lca = Get_Lca(x, y);
            ans = ask(root[x], root[y], root[Lca], root[st[Lca][0]], 1, t, k);
            ans = b[ans];
            printf("%d\n", ans);
        }else{
            int x = read(), y = read();
            x ^= ans, y ^= ans;
            add(x, y), add(y, x);
            int fx = Get(x), fy = Get(y);
            if(sz[fx] < sz[fy])
                swap(x, y), swap(fx, fy);
            dfs(y, x, fx);
        }
    }
    return 0;
}

【例题三】

Sign on Fence

要你在区间 \([l,r]\) 内选一个长度为 \(k\) 的区间,求区间最小数的最大值

最小值最大想到二分,将比 \(\rm mid\) 小的点设为 \(0\),其余点设为 \(1\),那么只需要判定是否有连续长度 \(\geq k\)\(1\) 串。

可惜如果这样每次暴力扫一遍建树,单次询问复杂度就达到了 \(O(n\log n)\)

但是其实每次二分得到的结果都差不多,那么其实可以重复利用的对吧。

考虑根据位置建立以值域为根的主席树,这和平常的主席树恰好相反,建树方式也恰好相反。

这样之后,每次二分值域只需要考虑以 \(\rm mid\) 为根的线段树中 \([l,r]\) 区间内是否有连连续长度 \(\geq k\)\(1\)

这是线段树的基本操作了,时间复杂度为优秀的 \(O(n\log^2 n)\)

const int N = 100010;

int n, q, T, tot;
int a[N], b[N], rk[N], root[N];
struct Tree{
    int ls, rs, len;
    int dat, ldat, rdat;
}t[N * 40];

int build(int l, int r){
    int p = ++tot;
    t[p].dat = t[p].ldat = t[p].rdat = 0;
    if(l == r){
        t[p].len = 1;
        return p;
    }
    int mid = (l + r) >> 1;
    t[p].ls = build(l, mid);
    t[p].rs = build(mid+1, r);
    t[p].len = t[t[p].ls].len + t[t[p].rs].len;
    return p;
}

void push_up(int p){
    int l = t[p].ls, r = t[p].rs;
    t[p].ldat = (t[l].ldat == t[l].len) ? t[l].ldat + t[r].ldat : t[l].ldat;
    t[p].rdat = (t[r].rdat == t[r].len) ? t[r].rdat + t[l].rdat : t[r].rdat;
    t[p].dat = max(t[l].dat, max(t[r].dat, t[l].rdat + t[r].ldat));
}

int insert(int last, int l, int r, int k){
    int p = ++tot;
    t[p] = t[last];
    if(l == r){
        t[p].dat = t[p].ldat = t[p].rdat = 1;
        return p;
    }
    int mid = (l + r) >> 1;
    if(k <= mid) t[p].ls = insert(t[last].ls, l, mid, k);
    else t[p].rs = insert(t[last].rs, mid+1, r, k);
    push_up(p);
    return p;
}

Tree merge(Tree a, Tree b){
    Tree c;
    c.ldat = (a.ldat == a.len) ? a.ldat + b.ldat : a.ldat;
    c.rdat = (b.rdat == b.len) ? a.rdat + b.rdat : b.rdat;
    c.dat = max(a.dat, max(b.dat, a.rdat + b.ldat));
    return c;
}

Tree ask(int p, int l, int r, int L, int R){
    if(l == L && r == R) return t[p];
    int mid = (l + r) >> 1;
    if(R <= mid)
        return ask(t[p].ls, l, mid, L, R);
    else if(L > mid)
        return ask(t[p].rs, mid+1, r, L, R);
    else
        return merge(ask(t[p].ls, l, mid, L, mid), ask(t[p].rs, mid+1, r, mid+1, R));
}

bool cmp(int x, int y) {return a[x] < a[y];}

int main(){
    n = read();
    for(int i=1; i<=n; i++) 
        a[i] = b[i] = read(), rk[i] = i;
    sort(b+1, b+n+1);
    sort(rk+1, rk+n+1, cmp);
    root[n + 1] = build(1, n);
    for(int i=n; i>=1; i--) 
        root[i] = insert(root[i+1], 1, n, rk[i]);
    q = read();
    while(q--){
        int L = read(), R = read(), k = read();
        int l = 1, r = n;
        while(l < r){
            int mid = (l + r + 1) >> 1;
            if(ask(root[mid], 1, n, L, R).dat >= k) l = mid;
            else r = mid - 1;
        }
        printf("%d\n", b[l]);
    }
    return 0;
}

【例题四】

Army Creation

每次问区间 \([l,r]\) 内最多可以选多少个数,满足同一个数的出现次数不超过 \(k\)

蒟蒻的第一道黑题耶,其实挺套路的。

假设 \(pos_i\) 表示第 \(i\) 个点的前面第 \(k\) 个与自己相等的节点,特别的,若前面没有 \(k\) 个点就令 \(pos_i=0\)

那么题目变为查询区间 \([l,r]\)\(pos_i<l\) 的有多少个数,主席树维护即可。

const int N = 100010;

int n, k, q, tot;
int root[N];
vector<int>pos[N];
struct Tree{int ls, rs, sum;} t[N * 40];

int build(int l, int r){
    int p = ++tot;
    t[p].sum = 0;
    if(l == r) return p;
    int mid = (l + r) >> 1;
    t[p].ls = build(l, mid);
    t[p].rs = build(mid+1, r);
    return p;
}

int insert(int last, int l, int r, int k){
    int p = ++tot;
    t[p] = t[last];
    if(l == r){
        t[p].sum ++;
        return p;
    }
    int mid = (l + r) >> 1;
    if(k <= mid) t[p].ls = insert(t[last].ls, l, mid, k);
    else t[p].rs = insert(t[last].rs, mid+1, r, k);
    t[p].sum = t[t[p].ls].sum + t[t[p].rs].sum;
    return p;
}

int ask(int p, int q, int l, int r, int limit){
    if(r <= limit) return t[p].sum - t[q].sum;
    int mid = (l + r) >> 1;
    int now = ask(t[p].ls, t[q].ls, l, mid, limit);
    if(limit > mid) now += ask(t[p].rs, t[q].rs, mid+1, r, limit);
    return now;
}

int main(){
    n = read(), k = read();
    root[0] = build(1, n);
    for(int i=1; i<=n; i++){
        int x = read();
        pos[x].push_back(i);
        int sz = pos[x].size();
        int p = sz > k ? pos[x][sz - k - 1] : 0;
        root[i] = insert(root[i-1], 0, n, p);
    }
    q = read();
    int ans = 0;
    while(q--){
        int l = read(), r = read();
        l = (l + ans) % n + 1;
        r = (r + ans) % n + 1;
        if(l > r) swap(l, r);
        ans = ask(root[r], root[l-1], 0, n, l-1);
        printf("%d\n", ans);
    }
    return 0;
}

【总结】

主席树的简单应用介绍到此结束。

可以发现题目中大部分没有修改操作(特别是区间修改),这是因为一般的主席树难以支持。

更好的方法是树状数组套主席树,有时可以吊打 整体二分 和 CDQ 分治。(有时又被它们吊打

推荐个题单

完结撒花。

posted @ 2021-03-16 16:43  LPF'sBlog  阅读(527)  评论(0编辑  收藏  举报