【学习笔记】线段树合并
具体地,线段树合并是指合并两棵动态开点权值线段树。
考虑合并过程,将树\(x\)合并到\(y.\)
当当前位置两棵树中有一者为空的时候,可以直接继承当前节点返回。
若到叶子的时候,直接合并信息。
然后对当前树\(x\)的左右叶子合并即可。
int merge(int x,int y,int l,int r){
if(!x||!y)return x|y;
if(l==r){sum[x]+=sum[y];tag[x]=l;return x;}
int mid=(l+r)>>1;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
pushup(x);return x;
}
对于代码中变量的解释:
\(x\)是树中节点的编号。\(sum[x]\)表示\(x\)点处答案的出现次数,\(tag[x]\)表示\(x\)处的答案。
答案记录为出现次数最多的物品编号。
树上更新信息时,考虑左子树优先(即左子树大于等于)来保证编号尽量小。继承的时候\(sum,tag\)要一起更新。
树上修改的时候,由于是单点修改,所以到叶子的时候更新信息,回来的时候\(pushup\)即可。
树上差分:对于路径\(<x,y>,\)我们考虑进行\(cnt[x]+1,cnt[y]+1,cnt[LCA]-1,cnt[fa[LCA]]-1.\)最后进行一次\(dfs\)合并答案。
求\(LCA\)代码中用了倍增。预处理了\(dep,pa,fa.\)
注意理解代码中的\(x,\)里面是数组版本的线段树,所以函数中需要传区间参数\([l,r].\)
注意理解代码中\(x,sum,tag\)的含义,是理解重点。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=6e6+10;
const int N=1e5+10;
int head[N],tot,cnt,n,m;
int ls[MAXN],rs[MAXN],f[N][21];
int dep[N],pa[N],ans[N];
struct E{int nxt,to;}e[N<<1];
inline void add(int x,int y){e[++tot]=(E){head[x],y};head[x]=tot;}
int sum[MAXN],tag[MAXN],rt[MAXN],R=100000;
void dfs(int x,int fa){
pa[x]=f[x][0]=fa;
dep[x]=dep[fa]+1;
for(int i=1;i<=20;++i)f[x][i]=f[f[x][i-1]][i-1];
for(int i=head[x];i;i=e[i].nxt){
int j=e[i].to;
if(j==fa)continue;
dfs(j,x);
}
}
inline void pushup(int x){
if(sum[ls[x]]>=sum[rs[x]])sum[x]=sum[ls[x]],tag[x]=tag[ls[x]];
else sum[x]=sum[rs[x]],tag[x]=tag[rs[x]];
}
int LCA(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=20;i>=0;--i)
if(dep[f[x][i]]>=dep[y])x=f[x][i];
if(x==y)return x;
for(int i=20;i>=0;--i)
if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
int change(int x,int l,int r,int pos,int v){
if(!x)x=++cnt;
if(l==r){
sum[x]+=v;
tag[x]=l;
return x;
}
int mid=(l+r)>>1;
if(pos<=mid)ls[x]=change(ls[x],l,mid,pos,v);
else rs[x]=change(rs[x],mid+1,r,pos,v);
pushup(x);return x;
}
int merge(int x,int y,int l,int r){
if(!x||!y)return x|y;
if(l==r){sum[x]+=sum[y];tag[x]=l;return x;}
int mid=(l+r)>>1;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
pushup(x);return x;
}
void redfs(int x){
for(int i=head[x];i;i=e[i].nxt){
int j=e[i].to;
if(j==pa[x])continue;
redfs(j);
rt[x]=merge(rt[x],rt[j],1,R);
}
if(sum[rt[x]])ans[x]=tag[rt[x]];
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1,x,y;i<n;++i){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs(1,0);
for(int i=1,x,y,z;i<=m;++i){
scanf("%d%d%d",&x,&y,&z);
int L=LCA(x,y);
rt[x]=change(rt[x],1,R,z,1);
rt[y]=change(rt[y],1,R,z,1);
rt[L]=change(rt[L],1,R,z,-1);
if(pa[L])rt[pa[L]]=change(rt[pa[L]],1,R,z,-1);
}
redfs(1);
for(int i=1;i<=n;++i)printf("%d\n",ans[i]);
return 0;
}