Loading

线段树分裂与合并

线段树合并与分裂

注意:下面的操作基本上是对权值平衡树的操作,换句话说,线段树合并于分裂大部分是在权值线段树上运用的。

1 线段树合并

我们需要把两颗线段树合并,怎么做?首先可以把一棵线段树的值一个一个加入到另一颗线段树中,但是这样的复杂度是 \(O(n\log n)\),但是这样不优美,我们考虑把两颗线段树的对应位置相加,像这样:

因为一共有最多有 \(n\log n\) 个节点,所以最坏情况下复杂度也是 \(O(n\log n)\) ,但是这种方法在通常情况下比上面要优,这是因为这颗二叉树在绝大多数情况下都不是一个满二叉树,而后面这种方法在遇到空子树的时候就会停下来。

代码:

inline void merge(int &a,int b,int l,int r){
    if(!a||!b){a=a+b;return;}
    if(l==r){p[a].sum+=p[b].sum;del(b);return;}
    int mid=(l+r)>>1;
    merge(p[a].l,p[b].l,l,mid);
    merge(p[a].r,p[b].r,mid+1,r);
    pushup(a);del(b);
}

比较好理解,这里不做讲解。注意这里的 del 函数是垃圾回收。在第二道例题中会有用处。

2 线段树分裂

像这样:

我们下面的程序中,是将以 \(a\) 为根的线段树中保留排名为 \(1\)\(k\) 中的数而把其他值给以 \(b\) 为根的线段树中。

inline void split(int a,int &b,int k){
    if(!a) return;
    b=new_node();
    int v=p[p[a].l].sum;
    if(v<k) split(p[a].r,p[b].r,k-v);
    else swap(p[a].r,p[b].r);
    if(v>k) split(p[a].l,p[b].l,k);
    p[b].sum=p[a].sum-k;p[a].sum=k;
}

注意:不要弄混该函数分裂的结果,是保留 \(1\)\(n\) 而不是把 \(1\)\(n\) 分裂出去。

上面这段代码什么意思?注意,有可能 \(b\) 的相应位置并没有节点,所以我们在第 \(3\) 行需要给他动态开点,而第 \(5\) 行是说,如果有 \(v<k\) ,即左子树的大小比 \(k\) 要小,说明 \(b\) 并没有拿走左子树,而只拿走了右子树的一部分。

否则,也就是第 \(6\) 行,不管是相等还是大于,右子树都要被分裂出去,所以才有了第 \(6\) 行的交换,然后第 \(7\) 行是看左子树有没有必要进行分裂。注意因为有动态开点,所以要给 \(b\) 加引用符号。

注意:无论是线段树合并还是线段树分裂,都要注意在主函数引用的时候,参数的顺序,合并中,合并到的位置是第一个参数,分裂中,分裂出的参数是第二个。

3 例题

3.1 P4556 [Vani有约会]雨天的尾巴 /【模板】线段树合并

链接

我们对每一个节点建一棵权值线段树,考虑树上差分,因为每一次操作只对 \(4\) 个节点,所以如果动态开点的话,时间和空间开销远远到不了上限,所以可以做。

查分后从下往上合并线段树,因为节点个数有限,在合并中不会加入新节点,所以复杂度不会太高。

这里查询 lca 用的是轻重链剖分。

#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 100100
#define M 6000100
using namespace std;

const int INF=0x3f3f3f3f;

template<typename T> inline void read(T &x) {
    x=0; int f=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    x*=f;
}

template<typename T> inline T Max(T a,T b){
    return a>b?a:b;
}

struct edge{
    int to,next;
    inline void intt(int to_,int ne_){
        to=to_;next=ne_;
    }
};
edge li[N<<1];
int head[N],tail;

inline void add(int from,int to){
    li[++tail].intt(to,head[from]);
    head[from]=tail;
}

int n,m;

struct node{
    int l,r,max_val,max_posi;
};
node p[M<<2];

int tot,root[M],max_right;
int top[M],siz[M],son[M],fa[M],deep[M],ans[M];

inline void dfs1(int k,int fat){
    fa[k]=fat;deep[k]=deep[fat]+1;siz[k]=1;
    for(int x=head[k];x;x=li[x].next){
        int to=li[x].to;
        if(to==fat) continue;
        dfs1(to,k);
        siz[k]+=siz[to];
        if(siz[son[k]]<siz[to]) son[k]=to;
    }
}

inline void dfs2(int k,int t){
    top[k]=t;
    if(son[k]) dfs2(son[k],t);
    for(int x=head[k];x;x=li[x].next){
        int to=li[x].to;
        if(to==fa[k]||to==son[k]) continue;
        dfs2(to,to);
    }
}

inline void pushup(int k){
    if(p[p[k].l].max_val>=p[p[k].r].max_val){
        p[k].max_val=p[p[k].l].max_val;
        p[k].max_posi=p[p[k].l].max_posi;
    }
    else{
        p[k].max_val=p[p[k].r].max_val;
        p[k].max_posi=p[p[k].r].max_posi;
    }
}

inline int find_lca(int a,int b){
    while(top[a]!=top[b]){
        if(deep[top[a]]<deep[top[b]]) swap(a,b);
        a=fa[top[a]];
    }
    if(deep[a]>deep[b]) swap(a,b);
    return a;
}

inline int new_node(){
    tot++;return tot;
}

inline void change(int &k,int l,int r,int x,int val){
    if(!k) k=new_node();
    if(l==r&&x==l){
        p[k].max_val+=val;p[k].max_posi=l;
        return;
    }
    int mid=(l+r)>>1;
    if(x<=mid) change(p[k].l,l,mid,x,val);
    else change(p[k].r,mid+1,r,x,val);
    pushup(k);
}

inline void merge(int &a,int b,int l,int r){
    if(!a||!b){a=a+b;return;}
    if(l==r){
        p[a].max_val+=p[b].max_val;
        p[a].max_posi=l;return;
    }
    int mid=(l+r)>>1;
    merge(p[a].l,p[b].l,l,mid);
    merge(p[a].r,p[b].r,mid+1,r);
    pushup(a);
}

inline void solve(int k){
    for(int x=head[k];x;x=li[x].next){
        int to=li[x].to;
        if(to==fa[k]) continue;
        solve(to);
        merge(root[k],root[to],1,max_right);
    }
    if(p[root[k]].max_val) ans[k]=p[root[k]].max_posi;
}

struct ques{
    int a,b,c;
    inline void intt(int a_,int b_,int c_){
        a=a_;b=b_;c=c_;
    }
};
ques qu[N];

int main(){
    read(n);read(m);
    for(int i=1;i<=n-1;i++){
        int from,to;
        read(from);read(to);
        add(from,to);add(to,from);
    }
    dfs1(1,0);dfs2(1,1);
    for(int i=1;i<=m;i++){
        int a,b,c;read(a);read(b);read(c);
        qu[i].intt(a,b,c);max_right=Max(max_right,qu[i].c);
    }
    for(int i=1;i<=m;i++){
        int lca=find_lca(qu[i].a,qu[i].b);
        change(root[qu[i].a],1,max_right,qu[i].c,1);
        change(root[qu[i].b],1,max_right,qu[i].c,1);
        change(root[lca],1,max_right,qu[i].c,-1);
        if(fa[lca]) change(root[fa[lca]],1,max_right,qu[i].c,-1);
    }
    solve(1);
    for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
    return 0;
}

3.2 线段树分裂模板

容易发现,每一个可重集其实都是一颗权值线段树。

操作 \(2\) 是线段树单点修改,操作 \(3\) 是线段树加法,操作 \(4\) 是在权值线段树上二分,操作 \(1\) 就是将线段树合并,那么操作 \(0\) 怎么做?考虑我们可以线段树分裂来做,注意线段树分裂依据的是排名,所以我们首先要查找一下\([1,x-1]\) 中的数有多少,记为 \(num_1\) ,然后再查找 \([x,y]\) 中的数有多少,记为 \(num_2\) 。我们先对第一棵线段树按照 \(num_1\) 进行分裂,存到新的一棵线段树中去,然后对这颗新的线段树按照 \(num_2\) 分裂,分裂出来的再与第一棵线段树合并,这个题就做完了。

代码:

#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define int long long
#define uint unsigned int
#define ull unsigned long long
#define N 200000
#define M number
using namespace std;

const int INF=0x3f3f3f3f;

template<typename T> inline void read(T &x) {
    x=0; int f=1;
    char c=getchar();
    for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
    for(;isdigit(c);c=getchar()) x=x*10+c-'0';
    x*=f;
}

struct node{
    int sum,l,r;
    node() {}
    node(int sum,int l,int r) : sum(sum),l(l),r(r) {}
};
node p[N<<4];

int root[N],roottail,n,m,delq[N],deltail,tot;

inline int new_node(){
    return deltail?delq[deltail--]:++tot;
}

inline void pushup(int k){
    p[k].sum=p[p[k].l].sum+p[p[k].r].sum;
}

inline void change(int &k,int l,int r,int x,int val){
    if(!k) k=new_node();
    if(l==r){
        p[k].sum+=val;return;
    }
    int mid=(l+r)>>1;
    if(x<=mid) change(p[k].l,l,mid,x,val);
    else change(p[k].r,mid+1,r,x,val);
    pushup(k);
}

inline void del(int k){
    p[k].sum=0;p[k].l=p[k].r=0;
    delq[++deltail]=k;
}

inline void merge(int &a,int b,int l,int r){
    if(!a||!b){a=a+b;return;}
    if(l==r){p[a].sum+=p[b].sum;del(b);return;}
    int mid=(l+r)>>1;
    merge(p[a].l,p[b].l,l,mid);
    merge(p[a].r,p[b].r,mid+1,r);
    pushup(a);del(b);
}

inline void split(int a,int &b,int k){
    if(!a) return;
    b=new_node();
    int v=p[p[a].l].sum;
    if(v<k) split(p[a].r,p[b].r,k-v);
    else swap(p[a].r,p[b].r);
    if(v>k) split(p[a].l,p[b].l,k);
    p[b].sum=p[a].sum-k;p[a].sum=k;
}

inline int ask_sum(int k,int l,int r,int z,int y){
    if(l==z&&r==y) return p[k].sum;
    int mid=(l+r)>>1;
    if(y<=mid) return ask_sum(p[k].l,l,mid,z,y);
    else if(z>mid) return ask_sum(p[k].r,mid+1,r,z,y);
    else return ask_sum(p[k].l,l,mid,z,mid)+ask_sum(p[k].r,mid+1,r,mid+1,y);
}

inline int get_val(int k,int l,int r,int rank){
    if(l==r) return l;
    int mid=(l+r)>>1;
    if(p[p[k].l].sum<rank) return get_val(p[k].r,mid+1,r,rank-p[p[k].l].sum);
    else return get_val(p[k].l,l,mid,rank);
}

signed main(){
    read(n);read(m);roottail=1;
    for(int i=1;i<=n;i++){
        int val;read(val);
        change(root[1],1,n,i,val);
    }
    for(int i=1;i<=m;i++){
        int op;read(op);
        if(op==0){
            int P,x,y,now;read(P);read(x);read(y);
            int num1=ask_sum(root[P],1,n,1,y);
            int num2=ask_sum(root[P],1,n,x,y);
            split(root[P],root[++roottail],num1-num2);split(root[roottail],now,num2);
            merge(root[P],now,1,n);
        }
        else if(op==1){
            int P,t;read(P);read(t);
            merge(root[P],root[t],1,n);
        }
        else if(op==2){
            int P,x,q;read(P);read(x);read(q);
            change(root[P],1,n,q,x);
        }
        else if(op==3){
            int P,x,y;read(P);read(x);read(y);
            printf("%lld\n",ask_sum(root[P],1,n,x,y));
        }
        else if(op==4){
            int P,rank;read(P);read(rank);
            if(rank<0||p[root[P]].sum<rank) printf("-1\n");
            else printf("%lld\n",get_val(root[P],1,n,rank));
        }
    }
    return 0;
}

这个题我们用到了垃圾回收,我们开一个栈保存我们已经删除的节点编号,然后新建节点时有限使用这些已经被删除的编号,这就是垃圾回收,用来卡空间。

posted @ 2021-06-28 15:56  hyl天梦  阅读(390)  评论(3编辑  收藏  举报