Part 1 —— 关于可持久化所谓可持久化,就可以简单理解为记录历史版本的数据结构。

不过其作用不只是简单地查询历史信息,比如线段树就可以实现查询区间第 kk 大,区间不大于 kk 的数个数等。

那么话不多说,正式开始吧!

\text{Part 2 —— 可持久化线段树}Part 2 —— 可持久化线段树这也是很多人接触的第一个可持久化数据结构。
先来看第一个模板题:静态区间第 kk 小

在考虑这个问题之前,先来想一个简单点的:动态整体第 kk 小 (单点修改,查询整体第 kk 小)
这个东西有很多种做法,其中用线段树常数比较优秀。

首先当然是离散化,然后对于出现的一个数 xx,在线段树上的 a_x+1ax+1 (也就是权值线段树)。
此时查询线段树区间 [l,r][l,r] 的和,得到的就是值域在 [l,r][l,r] 中的数个数。

于是查询代码这么写就好了:

int query(int l,int r,int u,int k){
    if(l==r) return l; //左右端点相等,显然只有 1 个结果
    int mid = (l+r)>>1;
    if(sum[u<<1]>=k) return query(l,mid,u<<1,k); //如果左儿子值域内的数多于 k 个,那么答案肯定在左边
    return query(mid+1,r,u<<1|1,k-sum[u<<1]); //否则答案就在右边,但是要注意 k 要减去左边的数个数
}

调用 query(1,n,1,k),得到的即使第 kk 大在排序后数组的下标。

对于区间第 kk 大,可以想到一种极为暴力的做法:
开 nn 个线段树,第 ii 个线段树上只记录 a_iai,然后查询 [l,r][l,r] 时把第 ll 到第 rr 棵线段树对应节点的信息全加起来,就是区间 [l,r][l,r] 建出的线段树的信息。

但这样还是太 naive 了,我们发现这 nn 棵线段树有很多重复的节点,可不可以共用呢?
这就是我们可持久化线段树的做法了。

这是一只可爱的线段树:
 这时候要修改第 66 个位置,它的改变会影响到 [1,6],[4,6],[5,6],[6,6][1,6],[4,6],[5,6],[6,644 个节点,所以把它们都复制一份,得到一个新的版本。

复制出的节点的儿子和原节点一样,然后再上面做修改,也就是这样:
 可以发现,把原来 [1,6],[4,6],[5,6],[6,6][1,6],[4,6],[5,6],[6,6] 这 44 个节点挡住不看,剩下的这部分就是在第 66 个位置修改 11 次后得到的线段树。

那么,把序列中的数依次加进去,得到了 nn 个版本的线段树。
对于求区间 [l,r][l,r] 建出的线段树信息时,只需要把第 rr 个版本和第 l-1l1 个版本的对应节点做差即可。

于是代码也就可以写出来啦:

#include<cstdio> 
#include<iostream>
#include<cstring>
#include<algorithm>
#define ll long long
#define N 200003
#define reg register
#define mid ((l+r)>>1)
using namespace std;

int rt[N],ls[N*19],rs[N*19],sum[N*19]; //rt[i] 为第 i 个版本线段树的根,ls 和 rs 表示左右儿子
int a[N],b[N];
int n,q,cnt;

inline void read(int &x);
void print(int x);
void build(int &u,int l,int r);
void update(int &u,int pre,int l,int r,int k);
int query(int u,int v,int l,int r,int k);

int main(){
    int l,r,m,k,t;
    read(n),read(q);
    for(reg int i=1;i<=n;++i){
        read(a[i]);
        b[i] = a[i];
    }
    sort(b+1,b+1+n);
    m = unique(b+1,b+1+n)-b-1;
    build(rt[0],1,m); //建树
    for(reg int i=1;i<=n;++i){
        a[i] = lower_bound(b+1,b+1+m,a[i])-b;
        update(rt[i],rt[i-1],1,m,a[i]); //在第 a[i] 个位置 +1
    }
    while(q--){
        read(l),read(r),read(k);
        t = query(rt[l-1],rt[r],1,m,k); //上面提到的,第 r 个版本和 l-1 个版本相减
        print(b[t]),putchar('\n');
    }    
    return 0;
}

int query(int u,int v,int l,int r,int k){
    if(l==r) return l;
    int x = sum[ls[v]]-sum[ls[u]]; //两部分对应相减,注意是左儿子,和上面整体第 k 大的差不多
    if(x>=k) return query(ls[u],ls[v],l,mid,k);
    return query(rs[u],rs[v],mid+1,r,k-x);
}   

void update(int &u,int pre,int l,int r,int k){
    //pre 表示从哪个节点复制过来的
    //&u 的作用是设置新建节点标号
    u = ++cnt;
    ls[u] = ls[pre],rs[u] = rs[pre]; //左右儿子先设为和原节点一样,后面再改
    sum[u] = sum[pre]+1;
    if(l==r) return;
    if(k<=mid) update(ls[u],ls[pre],l,mid,k); //注意原节点也跟着走 左/右 儿子
    else update(rs[u],rs[pre],mid+1,r,k);
}

void build(int &u,int l,int r){
    u = ++cnt; //每新建一个节点都要 ++cnt
    if(l==r) return;
    build(ls[u],l,mid);
    build(rs[u],mid+1,r);
}

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

void print(int x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

现在来搞定前面提到的:查询区间不大于 kk 的数个数。
对应题目:K-query (不过此题中求大于 kk 的个数,减一下就好了)

建树的方式和前面 完 全 一 致,来想想查询的函数怎么写:
查询不大于 kk 的数个数,也就是看值域 [1,k][1,k] 内有多少个数。

void query(int u,int v,int l,int r,int k){
    //此处的b数组还是排序后的数组
    if(l==r){
        if(b[l]<=k) ans += sum[v]-sum[u];
        //左右端点重合,直接看这个位置是否 <=k
        return;
    }
    if(k<=b[mid]) query(ls[u],ls[v],l,mid,k); //中间位置的数比 k 大的话,有贡献的部分只能在左边 
    else{
        ans += sum[ls[v]]-sum[ls[u]];
        //否则就把左边整个加上,再查询右边
        query(rs[u],rs[v],mid+1,r,k);
    }
}

考虑求树上路径的第 kk 大,也就是 Count on a tree

在树上我们还可以用前缀和的思想,在序列上的时候, uu 是从 u-1u1 复制过来的,现在从它的父亲复制过来即可。

要得到 u\rightarrow vuv 这条链上的点建出的线段树信息,只需要将 u,v,\text{lca}(u,v),\text{fa}(\text{lca}(u,v))u,v,lca(u,v),fa(lca(u,v)) 这几个节点对应 加/减 就好了。

代码如下:

#include<cstdio> 
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#define ll long long
#define N 200005
#define reg register
#define mid ((l+r)>>1)
using namespace std;

int fa[N],dep[N],size[N],son[N],top[N];
int a[N],b[N],ls[N*20],rs[N*20],sum[N*20],rt[N];
vector<int> adj[N];
int n,m,q,cnt;

struct Segment_Tree{
    void build(int &u,int l,int r){
        u = ++cnt;
        if(l==r) return;
        build(ls[u],l,mid);
        build(rs[u],mid+1,r);
    }

    void modify(int &u,int pre,int l,int r,int k){
        u = ++cnt;
        ls[u] = ls[pre],rs[u] = rs[pre];
        sum[u] = sum[pre]+1;
        if(l==r) return;
        if(k<=mid) modify(ls[u],ls[pre],l,mid,k);
        else modify(rs[u],rs[pre],mid+1,r,k);
    }

    int query(int u,int v,int p,int fp,int l,int r,int k){
        if(l==r) return l;
        int s = sum[ls[u]]+sum[ls[v]]-sum[ls[p]]-sum[ls[fp]];
        if(s>=k) return query(ls[u],ls[v],ls[p],ls[fp],l,mid,k);
        return query(rs[u],rs[v],rs[p],rs[fp],mid+1,r,k-s);
    }
}T;

void dfs1(int u,int f){
    fa[u] = f;
    dep[u] = dep[f]+1;
    size[u] = 1;
    T.modify(rt[u],rt[f],1,m,a[u]);
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==f) continue;
        dfs1(v,u);
        size[u] += size[v];
        if(size[v]>size[son[u]]) son[u] = v;
    }
}

void dfs2(int u,int f){
    top[u] = f;
    if(!son[u]) return;
    dfs2(son[u],f);
    int v,l = adj[u].size();
    for(int i=0;i<l;++i){
        v = adj[u][i];
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}

inline int lca(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        u = fa[top[u]];
    }
    return dep[u]<dep[v]?u:v;
}

int main(){
    int u,v,k,p,ans = 0;
    scanf("%d%d",&n,&q);
    for(reg int i=1;i<=n;++i){
        scanf("%d",&a[i]);
        b[i] = a[i];
    }
    for(reg int i=1;i<n;++i){
        scanf("%d%d",&u,&v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    sort(b+1,b+1+n);
    m = unique(b+1,b+1+n)-b-1;
    for(reg int i=1;i<=n;++i)
        a[i] = lower_bound(b+1,b+1+m,a[i])-b;
    T.build(rt[0],1,m);
    dfs1(1,0),dfs2(1,1);
    while(q--){
        scanf("%d%d%d",&u,&v,&k);
        u ^= ans;
        p = lca(u,v);
        ans = b[T.query(rt[u],rt[v],rt[p],rt[fa[p]],1,m,k)];
        printf("%d\n",ans);
    }
    return 0;
}

那么接下来又是一个比较难的问题:Dynamic Ranking

对于之前静态的问题,我们使用前缀和,第 nn 个版本的线段树存了区间 [1,n][1,n] 的信息。
但是现在带修改,就要用树状数组维护前缀和。 虽然比较容易想,但是也很容易写挂。

所以有一些要注意的点:

1、修改的时候,每个节点是从它自己,而不是前面那个节点复制过来。
2、树状数组的使用和一般差不多,只是查询时要处理出哪些线段树加/减。
3、由于值域很大,所以要离线,把修改的值也一起离散化。

嗯大概就是这样了,剩下的细节详见代码:
https://www.luogu.org/paste/3i2fyu82


我们再找一道练习题巩固一下:[CQOI2011]动态逆序对
这题问每次删除一个数前,逆序对的数量。

不难想到,当一个数被删除时,答案减少的量为: 左边比它大的数 + 右边比它小的数。
这是可以用上面类似的做法查询的。

但是这样空间不太够,总共进行了 150000150000 次修改。

不过可以反过来想:先直接算出初始情况的答案,再把删除一个数看成加入一个数,简单容斥一下即可得到答案。

代码:
https://www.luogu.org/paste/33aci5fk

\text{Part 3 —— 可持久化平衡树}Part 3 —— 可持久化平衡树

咕咕咕

\text {Part4 —— 可持久化并查集}Part4 —— 可持久化并查集

有一种比较简单的做法,是用可持久化线段树记录每个点的父亲来实现。

但是这样就不能用路径压缩,而需要改用按秩合并。
所谓按秩合并,也就是把最大深度小的连通块并到最大深度大的上面去,这样可以保证深度为 \text O(\log n)O(logn) 。

既然这样,线段树上还要记录每个点的深度。
现在只需要支持三种操作:

1、更改一个点的父亲
2、增加一个连通块中所有点的深度
3、查询一个点的父亲

第一个很简单,按照普通可持久化的办法搞,深度直接从上一个版本复制过来。
第二个也没什么好说的,看修改的点在左半边还是右半边,递归修改。
最后一个和第二个很像,递归到底层返回当前节点编号即可。

要注意的是,查询时得出的结果,还要作为下标在 \text{father}father 数组中查询才是真正的父节点。

可持久化并查集的 \text{find}find 操作也是一路找父亲上去,只不过不用修改(不用路径压缩)。

合并的时候看最大深度如果不同,直接并上去就好了;否则就把一个块内深度都 +1+1。

时间复杂度 \text O(q \log^2 n)O(qlog2n)

下面是代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#define mid ((l+r)>>1)
#define reg register
#define N 200003
#define M 4000003
using namespace std;

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

int rt[N],ls[M],rs[M],dep[M],fa[M];
int n,q,cnt;

void build(int &u,int l,int r){
    u = ++cnt;
    if(l==r){
        fa[u] = l; //别忘了并查集初始化,每个点的父亲是自己
        return;
    }
    build(ls[u],l,mid);
    build(rs[u],mid+1,r);
}

void modify(int &u,int pre,int l,int r,int pos,int f){  //将第 pos 个点的父亲变成 f
    u = ++cnt;
    if(l==r){
        fa[u] = f;
        dep[u] = dep[pre];
        return;
    }
    ls[u] = ls[pre],rs[u] = rs[pre];
    if(pos<=mid) modify(ls[u],ls[pre],l,mid,pos,f);
    else modify(rs[u],rs[pre],mid+1,r,pos,f);
}

int query(int u,int l,int r,int pos){
    if(l==r) return u;
    if(pos<=mid) return query(ls[u],l,mid,pos);
    return query(rs[u],mid+1,r,pos);
}

void add(int u,int l,int r,int pos){
    if(l==r){
        ++dep[u];
        return;
    }
    if(pos<=mid) add(ls[u],l,mid,pos);
    else add(rs[u],mid+1,r,pos);
}

int find(int u,int x){ //一路上去找 x 的父亲,查询的线段树根为 u
    int f = query(u,1,n,x);
    if(x==fa[f]) return f;
    return find(u,fa[f]);
}

int main(){
    int op,u,v,fu,fv;
    read(n),read(q);
    build(rt[0],1,n);
    for(reg int i=1;i<=q;++i){
        read(op),read(u);
        if(op!=2) read(v);
        if(op==1){
            rt[i] = rt[i-1]; //先从上一个版本复制
            fu = find(rt[i],u);
            fv = find(rt[i],v);
            if(fa[fu]==fa[fv]) continue; //连通就跳过去
            if(dep[fu]>dep[fv]) swap(fu,fv); //很重要的按秩合并
            modify(rt[i],rt[i-1],1,n,fa[fu],fa[fv]);
            if(dep[fu]==dep[fv]) add(rt[i],1,n,fv);
        }else if(op==2) rt[i] = rt[u];
        else{
            rt[i] = rt[i-1];
            fu = find(rt[i],u);
            fv = find(rt[i],v);
            putchar(fa[fu]==fa[fv]?'1':'0');
            putchar('\n');
        }
    }
    return 0;
}

\text {Part5 —— 可持久化 01Trie}Part5 —— 可持久化 01Trie其实这个东西比可持久化线段树好写不少(虽然本质上是相同的)。。
正常的 Trie 大家应该都会,这里就不讲了(逃

先把例题搬出来:最大异或和
对于这个题,还是很套路地做一下前缀和,这样对于区间的异或和就变成两个值的异或了。
做完前缀和后,查询等价于在 [l-1,r-1][l1,r1] 内找一个位置 pp,使得 \text{a}[p]a[p] 与 \text{a}[n] \text{ xor } xa[n] xor x 的异或值最大。

所以在建立可持久化 Trie 时,也用第 ii 个版本记录 [1,i][1,i] 区间建出的 Trie 树信息,用两个相减得到区间信息。

要实现可持久化,套路地在每次修改的时候新建节点就行了

void insert(int u,int pre,int i,int x){ 
    if(i<0) return; //i表示递归算到第i位
    int t = (x>>i)&1;
    son[u][t^1] = son[pre][t^1];
    son[u][t] = ++cnt; //多出来的,新建节点
    sum[son[u][t]] = sum[son[pre][t]]+1; //对应位1的个数增加
    insert(son[u][t],son[pre][t],i-1,x); 
}

不过上面的这份代码在插入之前,需要新建一下根节点,不然也会出错。
查询也是一样,只不过变成了两个点做差,具体见代码。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<ctime>
#define ll long long
#define reg register
#define N 600003
#define M 27654321
using namespace std;

inline void read(int &x){
    x = 0;
    char c = getchar();
    while(c<'0'||c>'9') c = getchar();
    while(c>='0'&&c<='9'){
        x = (x<<3)+(x<<1)+(c^48);
        c = getchar();
    }
}

void print(int x){
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

int a[N],rt[N],son[M][2],sum[M];
int n,q,cnt;

void insert(int u,int pre,int i,int x){
    if(i<0) return;
    int t = (x>>i)&1;
    son[u][t^1] = son[pre][t^1];
    son[u][t] = ++cnt;
    sum[son[u][t]] = sum[son[pre][t]]+1;
    insert(son[u][t],son[pre][t],i-1,x);
}

int query(int u,int v,int i,int x){ //做差得到区间信息,这里移项了一下,效果是一样的
    if(i<0) return 0;
    int t = (x>>i)&1;
    if(sum[son[v][t^1]]>sum[son[u][t^1]]) return (1<<i)|query(son[u][t^1],son[v][t^1],i-1,x);
    return query(son[u][t],son[v][t],i-1,x);
}

int main(){
    rt[0] = ++cnt; 
    insert(rt[0],0,25,0);
    reg int l,r,x;
    read(n),read(q);
    for(reg int i=1;i<=n;++i) read(a[i]);
    for(reg int i=1;i<=n;++i){
        a[i] ^= a[i-1];
        rt[i] = ++cnt;
        insert(rt[i],rt[i-1],25,a[i]);
    }
    while(q--){
        char c = getchar();
        while(c!='A'&&c!='Q') c = getchar();
        if(c=='A'){
            read(x);
            ++n;
            rt[n] = ++cnt;
            a[n] = a[n-1]^x;
            insert(rt[n],rt[n-1],25,a[n]);
        }else{
            read(l),read(r),read(x);
            print(query(l>1?rt[l-2]:0,rt[r-1],25,x^a[n]));
            putchar('\n');
        }
    }
    return 0;
}