线段树分裂与合并
线段树合并与分裂
注意:下面的操作基本上是对权值平衡树的操作,换句话说,线段树合并于分裂大部分是在权值线段树上运用的。
1 线段树合并
我们需要把两颗线段树合并,怎么做?首先可以把一棵线段树的值一个一个加入到另一颗线段树中,但是这样的复杂度是 \(O(n\log n)\),但是这样不优美,我们考虑把两颗线段树的对应位置相加,像这样:
因为一共有最多有 \(n\log n\) 个节点,所以最坏情况下复杂度也是 \(O(n\log n)\) ,但是这种方法在通常情况下比上面要优,这是因为这颗二叉树在绝大多数情况下都不是一个满二叉树,而后面这种方法在遇到空子树的时候就会停下来。
代码:
inline void merge(int &a,int b,int l,int r){
if(!a||!b){a=a+b;return;}
if(l==r){p[a].sum+=p[b].sum;del(b);return;}
int mid=(l+r)>>1;
merge(p[a].l,p[b].l,l,mid);
merge(p[a].r,p[b].r,mid+1,r);
pushup(a);del(b);
}
比较好理解,这里不做讲解。注意这里的 del
函数是垃圾回收。在第二道例题中会有用处。
2 线段树分裂
像这样:
我们下面的程序中,是将以 \(a\) 为根的线段树中保留排名为 \(1\) 到 \(k\) 中的数而把其他值给以 \(b\) 为根的线段树中。
inline void split(int a,int &b,int k){
if(!a) return;
b=new_node();
int v=p[p[a].l].sum;
if(v<k) split(p[a].r,p[b].r,k-v);
else swap(p[a].r,p[b].r);
if(v>k) split(p[a].l,p[b].l,k);
p[b].sum=p[a].sum-k;p[a].sum=k;
}
注意:不要弄混该函数分裂的结果,是保留 \(1\) 到 \(n\) 而不是把 \(1\) 到 \(n\) 分裂出去。
上面这段代码什么意思?注意,有可能 \(b\) 的相应位置并没有节点,所以我们在第 \(3\) 行需要给他动态开点,而第 \(5\) 行是说,如果有 \(v<k\) ,即左子树的大小比 \(k\) 要小,说明 \(b\) 并没有拿走左子树,而只拿走了右子树的一部分。
否则,也就是第 \(6\) 行,不管是相等还是大于,右子树都要被分裂出去,所以才有了第 \(6\) 行的交换,然后第 \(7\) 行是看左子树有没有必要进行分裂。注意因为有动态开点,所以要给 \(b\) 加引用符号。
注意:无论是线段树合并还是线段树分裂,都要注意在主函数引用的时候,参数的顺序,合并中,合并到的位置是第一个参数,分裂中,分裂出的参数是第二个。
3 例题
3.1 P4556 [Vani有约会]雨天的尾巴 /【模板】线段树合并
我们对每一个节点建一棵权值线段树,考虑树上差分,因为每一次操作只对 \(4\) 个节点,所以如果动态开点的话,时间和空间开销远远到不了上限,所以可以做。
查分后从下往上合并线段树,因为节点个数有限,在合并中不会加入新节点,所以复杂度不会太高。
这里查询 lca 用的是轻重链剖分。
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 100100
#define M 6000100
using namespace std;
const int INF=0x3f3f3f3f;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
template<typename T> inline T Max(T a,T b){
return a>b?a:b;
}
struct edge{
int to,next;
inline void intt(int to_,int ne_){
to=to_;next=ne_;
}
};
edge li[N<<1];
int head[N],tail;
inline void add(int from,int to){
li[++tail].intt(to,head[from]);
head[from]=tail;
}
int n,m;
struct node{
int l,r,max_val,max_posi;
};
node p[M<<2];
int tot,root[M],max_right;
int top[M],siz[M],son[M],fa[M],deep[M],ans[M];
inline void dfs1(int k,int fat){
fa[k]=fat;deep[k]=deep[fat]+1;siz[k]=1;
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fat) continue;
dfs1(to,k);
siz[k]+=siz[to];
if(siz[son[k]]<siz[to]) son[k]=to;
}
}
inline void dfs2(int k,int t){
top[k]=t;
if(son[k]) dfs2(son[k],t);
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa[k]||to==son[k]) continue;
dfs2(to,to);
}
}
inline void pushup(int k){
if(p[p[k].l].max_val>=p[p[k].r].max_val){
p[k].max_val=p[p[k].l].max_val;
p[k].max_posi=p[p[k].l].max_posi;
}
else{
p[k].max_val=p[p[k].r].max_val;
p[k].max_posi=p[p[k].r].max_posi;
}
}
inline int find_lca(int a,int b){
while(top[a]!=top[b]){
if(deep[top[a]]<deep[top[b]]) swap(a,b);
a=fa[top[a]];
}
if(deep[a]>deep[b]) swap(a,b);
return a;
}
inline int new_node(){
tot++;return tot;
}
inline void change(int &k,int l,int r,int x,int val){
if(!k) k=new_node();
if(l==r&&x==l){
p[k].max_val+=val;p[k].max_posi=l;
return;
}
int mid=(l+r)>>1;
if(x<=mid) change(p[k].l,l,mid,x,val);
else change(p[k].r,mid+1,r,x,val);
pushup(k);
}
inline void merge(int &a,int b,int l,int r){
if(!a||!b){a=a+b;return;}
if(l==r){
p[a].max_val+=p[b].max_val;
p[a].max_posi=l;return;
}
int mid=(l+r)>>1;
merge(p[a].l,p[b].l,l,mid);
merge(p[a].r,p[b].r,mid+1,r);
pushup(a);
}
inline void solve(int k){
for(int x=head[k];x;x=li[x].next){
int to=li[x].to;
if(to==fa[k]) continue;
solve(to);
merge(root[k],root[to],1,max_right);
}
if(p[root[k]].max_val) ans[k]=p[root[k]].max_posi;
}
struct ques{
int a,b,c;
inline void intt(int a_,int b_,int c_){
a=a_;b=b_;c=c_;
}
};
ques qu[N];
int main(){
read(n);read(m);
for(int i=1;i<=n-1;i++){
int from,to;
read(from);read(to);
add(from,to);add(to,from);
}
dfs1(1,0);dfs2(1,1);
for(int i=1;i<=m;i++){
int a,b,c;read(a);read(b);read(c);
qu[i].intt(a,b,c);max_right=Max(max_right,qu[i].c);
}
for(int i=1;i<=m;i++){
int lca=find_lca(qu[i].a,qu[i].b);
change(root[qu[i].a],1,max_right,qu[i].c,1);
change(root[qu[i].b],1,max_right,qu[i].c,1);
change(root[lca],1,max_right,qu[i].c,-1);
if(fa[lca]) change(root[fa[lca]],1,max_right,qu[i].c,-1);
}
solve(1);
for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
return 0;
}
3.2 线段树分裂模板
容易发现,每一个可重集其实都是一颗权值线段树。
操作 \(2\) 是线段树单点修改,操作 \(3\) 是线段树加法,操作 \(4\) 是在权值线段树上二分,操作 \(1\) 就是将线段树合并,那么操作 \(0\) 怎么做?考虑我们可以线段树分裂来做,注意线段树分裂依据的是排名,所以我们首先要查找一下\([1,x-1]\) 中的数有多少,记为 \(num_1\) ,然后再查找 \([x,y]\) 中的数有多少,记为 \(num_2\) 。我们先对第一棵线段树按照 \(num_1\) 进行分裂,存到新的一棵线段树中去,然后对这颗新的线段树按照 \(num_2\) 分裂,分裂出来的再与第一棵线段树合并,这个题就做完了。
代码:
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define int long long
#define uint unsigned int
#define ull unsigned long long
#define N 200000
#define M number
using namespace std;
const int INF=0x3f3f3f3f;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
struct node{
int sum,l,r;
node() {}
node(int sum,int l,int r) : sum(sum),l(l),r(r) {}
};
node p[N<<4];
int root[N],roottail,n,m,delq[N],deltail,tot;
inline int new_node(){
return deltail?delq[deltail--]:++tot;
}
inline void pushup(int k){
p[k].sum=p[p[k].l].sum+p[p[k].r].sum;
}
inline void change(int &k,int l,int r,int x,int val){
if(!k) k=new_node();
if(l==r){
p[k].sum+=val;return;
}
int mid=(l+r)>>1;
if(x<=mid) change(p[k].l,l,mid,x,val);
else change(p[k].r,mid+1,r,x,val);
pushup(k);
}
inline void del(int k){
p[k].sum=0;p[k].l=p[k].r=0;
delq[++deltail]=k;
}
inline void merge(int &a,int b,int l,int r){
if(!a||!b){a=a+b;return;}
if(l==r){p[a].sum+=p[b].sum;del(b);return;}
int mid=(l+r)>>1;
merge(p[a].l,p[b].l,l,mid);
merge(p[a].r,p[b].r,mid+1,r);
pushup(a);del(b);
}
inline void split(int a,int &b,int k){
if(!a) return;
b=new_node();
int v=p[p[a].l].sum;
if(v<k) split(p[a].r,p[b].r,k-v);
else swap(p[a].r,p[b].r);
if(v>k) split(p[a].l,p[b].l,k);
p[b].sum=p[a].sum-k;p[a].sum=k;
}
inline int ask_sum(int k,int l,int r,int z,int y){
if(l==z&&r==y) return p[k].sum;
int mid=(l+r)>>1;
if(y<=mid) return ask_sum(p[k].l,l,mid,z,y);
else if(z>mid) return ask_sum(p[k].r,mid+1,r,z,y);
else return ask_sum(p[k].l,l,mid,z,mid)+ask_sum(p[k].r,mid+1,r,mid+1,y);
}
inline int get_val(int k,int l,int r,int rank){
if(l==r) return l;
int mid=(l+r)>>1;
if(p[p[k].l].sum<rank) return get_val(p[k].r,mid+1,r,rank-p[p[k].l].sum);
else return get_val(p[k].l,l,mid,rank);
}
signed main(){
read(n);read(m);roottail=1;
for(int i=1;i<=n;i++){
int val;read(val);
change(root[1],1,n,i,val);
}
for(int i=1;i<=m;i++){
int op;read(op);
if(op==0){
int P,x,y,now;read(P);read(x);read(y);
int num1=ask_sum(root[P],1,n,1,y);
int num2=ask_sum(root[P],1,n,x,y);
split(root[P],root[++roottail],num1-num2);split(root[roottail],now,num2);
merge(root[P],now,1,n);
}
else if(op==1){
int P,t;read(P);read(t);
merge(root[P],root[t],1,n);
}
else if(op==2){
int P,x,q;read(P);read(x);read(q);
change(root[P],1,n,q,x);
}
else if(op==3){
int P,x,y;read(P);read(x);read(y);
printf("%lld\n",ask_sum(root[P],1,n,x,y));
}
else if(op==4){
int P,rank;read(P);read(rank);
if(rank<0||p[root[P]].sum<rank) printf("-1\n");
else printf("%lld\n",get_val(root[P],1,n,rank));
}
}
return 0;
}
这个题我们用到了垃圾回收,我们开一个栈保存我们已经删除的节点编号,然后新建节点时有限使用这些已经被删除的编号,这就是垃圾回收,用来卡空间。