线段树合并学习笔记
线段树合并
过程:
顾名思义,线段树合并是指建立一棵新的线段树,这棵线段树的每个节点都是两棵原线段树对应节点合并后的结果。它常常被用于维护树上或是图上的信息。
一般每个点建一棵线段树,以子树或者题目要求进行合并(比如连通块)。
实现:
我们考虑每次递归合并。把线段树 \(b\) 上的信息和线段树 \(a\) 上的信息合并更新。并且递归下去。如果某个线段树节点为空则直接返回。
例题:
P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并
题意
给一棵树,一共 \(m\) 次操作,操作点对 \((x,y,z)\) 表示在 \(x\to y\) 的路径上每个点对于 \(x\) 种类计数器 \(+1\)。最后请输出对于这棵树上每个点计数器值最大的种类。
解法
我们考虑一个暴力的想法,即我们对于 \(x\to1\) 的路径上 \(cnt_{z}+1\),对于 \(y\to 1\) 的路径上 \(cnt_z+1\)。对于 \(lca(x,y) \to 1\) 的路径上 \(cnt_z-1\),对于\(fa_{lca(x,y)}\to 1\) 的路径上 \(cnt_z-1\)。建 \(n\) 棵线段树,区间加,单点查。但是对于这道题无法通过。
我们考虑树上差分,即对于这样一个询问,我们仅对于 \(x\) 的 \(cnt_{z}+1\),对于 \(y\) 的 \(cnt_z+1\)。对于 \(lca(x,y)\) 的 \(cnt_z-1\),对于$fa_{lca(x,y)} $ 的 \(cnt_z-1\)。这样答案可以通过子树和方式求得。
说干就干,但是子树和如果通过数组按位相加再比较实在是太慢了!所以我们考虑线段树合并,我们把答案子树 \(u\) 内答案都合并到 \(u\) 这棵线段树上来,然后直接查答案即可。
这题线段树需要动态开点,否则空间不够。
\(lca\) 使用树剖维护即可。
复杂度 \(O(n\log n+m\log n)\),足够通过本题。
代码
#include<bits/stdc++.h>
using namespace std;
const int N=100000;
#define mid ((l+r)>>1)
int n,rt[100005];
int sum[5000005],cnt=0,res[5000005],ls[5000005],rs[5000005];
int m,ans[100005];
vector<int> G[100005];
//sum 出现次数,res种类--也就是答案
struct tree{
int top[5000005];
int fa[5000005],bgs[5000005],dep[5000005],siz[5000005];
void dfs(int x,int fat){
dep[x]=dep[fat]+1;
fa[x]=fat;siz[x]=1;
for(int k:G[x]){
if(k!=fat){
dfs(k,x);
siz[x]+=siz[k];
if(siz[k]>siz[bgs[x]]) bgs[x]=k;
}
}
}
void DFS(int x,int fat,int tp){
top[x]=tp;
if(bgs[x]){
DFS(bgs[x],x,tp);
}
for(int k:G[x]){
if(k!=fat&&k!=bgs[x]) DFS(k,x,k);
}
}
int lca(int x,int y){
while(top[x]^top[y]) dep[top[x]]>dep[top[y]]?x=fa[top[x]]:y=fa[top[y]];
return dep[x]<dep[y]?x:y;
}
}_sp;
struct seg_tree{
int merge(int a,int b,int l,int r){
if(!a) return b;
if(!b) return a;
if(l==r) return sum[a]+=sum[b],a;
ls[a]=merge(ls[a],ls[b],l,mid),rs[a]=merge(rs[a],rs[b],mid+1,r);
pushup(a,ls[a],rs[a]);
return a;
}
void pushup(int k,int l,int r){
if(sum[ls[k]]<sum[rs[k]]) res[k]=res[rs[k]],sum[k]=sum[rs[k]];
else res[k]=res[ls[k]],sum[k]=sum[ls[k]];
}
int modify(int k,int l,int r,int co,int val){
if(!k) k=++cnt;
if(l==r) return sum[k]+=val,res[k]=co,k;
if(co<=mid) ls[k]=modify(ls[k],l,mid,co,val);
else rs[k]=modify(rs[k],mid+1,r,co,val);
pushup(k,l,r);
return k;
}
void get_ans(int x){
for(int k:G[x]){
if(k==_sp.fa[x]) continue;
get_ans(k);
rt[x]=merge(rt[x],rt[k],1,100000);
}
ans[x]=res[rt[x]];
if(sum[rt[x]]==0) ans[x]=0;
}
}T;
void read(){
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++){
int u,v;scanf("%d%d",&u,&v);
G[u].push_back(v),G[v].push_back(u);
}
_sp.dfs(1,0),_sp.DFS(1,0,1);
}
void solve(){
for(int i=1;i<=m;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
rt[x]=T.modify(rt[x],1,N,z,1),rt[y]=T.modify(rt[y],1,N,z,1);
int _lca=_sp.lca(x,y);
// printf("lca:%d\n",_lca);
rt[_lca]=T.modify(rt[_lca],1,N,z,-1);
rt[_sp.fa[_lca]]=T.modify(rt[_sp.fa[_lca]],1,N,z,-1);
}
// return;
T.get_ans(1);
for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
}
int main(){
// freopen("data.in","r",stdin);
read();
// return 0;
solve();
return 0;
}
P3224 [HNOI2012] 永无乡
题意
给定一个图,每个点都有一个重要度。两种操作,操作一连接两个点,操作二求对于某个点连通的所有点中排名第 \(y\) 小的是哪个。
解法
考虑用并查集维护连通性。对于每个连通块用线段树维护重要度。
-
对于每次联通操作,先并查集合并,再线段树合并。
-
对于每次查询,我们在所对应的连通块所对应的线段树上二分查找即可。
复杂度 \(O(n\log n)\),足以通过本题。线段树要动态开点。注意合并时最后要 return a
。
代码
#include<bits/stdc++.h>
using namespace std;
#define mid ((l+r)>>1)
const int N=1e5+7;
const int M=3e6+7;
int fa[N];
int n,m,q;
int rt[N],cnt,ls[M],rs[M],rnk[M],sum[M];
int x,y;
struct node{
int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
}bcj;
struct seg_tree{
void pushup(int k,int l,int r){
sum[k]=sum[ls[k]]+sum[rs[k]];
}
int modify(int k,int l,int r,int pos,int idx){
if(!k) k=++cnt;
if(l==r) return rnk[k]=idx,sum[k]++,k;
if(pos<=mid) ls[k]=modify(ls[k],l,mid,pos,idx);
else rs[k]=modify(rs[k],mid+1,r,pos,idx);
pushup(k,l,r);
return k;
}
int query(int k,int l,int r,int val){
if(sum[k]<val||!k) return 0;
if(l==r) return rnk[k];
if(val<=sum[ls[k]]) return query(ls[k],l,mid,val);
else return query(rs[k],mid+1,r,val-sum[ls[k]]);
}
int merge(int a,int b,int l,int r){
if(!a) return b;
if(!b) return a;
if(l==r){
if(rnk[b]) rnk[a]=rnk[b],sum[a]+=sum[b];
return a;
}
ls[a]=merge(ls[a],ls[b],l,mid);
rs[a]=merge(rs[a],rs[b],mid+1,r);
pushup(a,l,r);
return a;
}
}T;
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
fa[i]=i;int x;scanf("%d",&x);
rt[i]=T.modify(rt[i],1,n,x,i);
}
for(int i=1;i<=m;i++){
int x,y;scanf("%d%d",&x,&y);
x=bcj.find(x),y=bcj.find(y);
fa[y]=x;
rt[x]=T.merge(rt[x],rt[y],1,n);
}
char s[5];
scanf("%d",&q);
while(q--){
scanf("%s",s);scanf("%d%d",&x,&y);
if(s[0]=='B'){
x=bcj.find(x),y=bcj.find(y);
if(x==y) continue;
fa[y]=x;
rt[x]=T.merge(rt[x],rt[y],1,n);
}
else{
x=bcj.find(x);
int ans=T.query(rt[x],1,n,y);
if(!ans) printf("-1\n");
else printf("%d\n",ans);
}
}
return 0;
}
P3899 [湖南集训] 更为厉害
题意
求对于一个点编号为 \(p\) 的点 \(a\),求树上有多少三元组 \((a,b,c)\) 满足 \(a,b\) 是 \(c\) 的祖先,并且 \(dis(a,b)\le k\)。(\(k\) 为给定的常数)。
分析
我们考虑分类讨论。由题意可得,\(a,b,c\) 一定在一条链上。
- 若 \(b\) 在 \(a\) 上方,则最多可以跳到 \(dep_a-k\) 处(不跳出去的情况下)。因而此时答案为 \((size_a-1)\times \min(dep_a-1,k)\)。
- 若 \(b\) 在 \(a\) 下方,那么 \(b\) 可以在 \(a\) 子树深度为 \([deep_a+1,deep_a+k]\) 范围内,此时 \(c\) 的数量就是 \(size_b-1\),因而我们考虑建立权值线段树,求下标为 \([deep_a+1,deep_a+k]\) 区间内 \(size-1\) 的和即可。对于每个点维护每个深度的答案即可,考虑子树内所有点所在线段树需要向该子树根节点做线段树合并。
代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define mid ((l+r)>>1)
const int N=3e5+7;
int n,q,rt[N],cnt,dep[N],sz[N];
int ans[N];
vector<int> G[N];
vector<pair<int,int> > Q[N];
int ls[N*10<<2],rs[N*10<<2],sum[N*10<<2];
struct seg_tree{
void pushup(int k,int l,int r){
sum[k]=sum[ls[k]]+sum[rs[k]];
}
int modify(int k,int l,int r,int x,int val){
if(!k) k=++cnt;
if(l==r) return (sum[k]+=val),k;
if(x<=mid) ls[k]=modify(ls[k],l,mid,x,val);
else rs[k]=modify(rs[k],mid+1,r,x,val);
pushup(k,l,r);
return k;
}
int merge(int a,int b,int l,int r){
if(!a) return b;if(!b) return a;
if(l==r) return (sum[a]+=sum[b]),a;
ls[a]=merge(ls[a],ls[b],l,mid);
rs[a]=merge(rs[a],rs[b],mid+1,r);
pushup(a,l,r);
return a;
}
int query(int k,int l,int r,int x,int y){
if(!k) return 0;
if(x<=l&&y>=r) return sum[k];
int res=0;
if(x<=mid) res+=query(ls[k],l,mid,x,y);
if(y>=mid+1) res+=query(rs[k],mid+1,r,x,y);
return res;
}
}T;
void DFS(int x,int fa){
dep[x]=dep[fa]+1,sz[x]=1;
for(int k:G[x]){
if(k==fa) continue;
DFS(k,x);
rt[x]=T.merge(rt[x],rt[k],1,n);
sz[x]+=sz[k];
}
rt[x]=T.modify(rt[x],1,n,dep[x],sz[x]-1);
for(auto it:Q[x]){
int id=it.first,k=it.second;
ans[id]=T.query(rt[x],1,n,min(n,dep[x]+1),min(dep[x]+k,n))+(sz[x]-1)*min(dep[x]-1,k);
}
}
signed main(){
scanf("%lld%lld",&n,&q);
for(int i=1;i<n;i++){
int u,v;scanf("%lld%lld",&u,&v);
G[u].push_back(v),G[v].push_back(u);
}
for(int i=1;i<=q;i++){
int x,k;scanf("%lld%lld",&x,&k);
Q[x].push_back(make_pair(i,k));
}
DFS(1,0);
for(int i=1;i<=q;i++) printf("%lld\n",ans[i]);
return 0;
}
P3605 [USACO17JAN] Promotion Counting P
分析
板子。
代码
#include<bits/stdc++.h>
using namespace std;
#define mid ((l+r)>>1)
const int N=1e5+7;
int n,m;
int a[N],b[N];
int fa[N];
int data[N*40];
int ls[N*40],rs[N*40];
int tot;
int rt[N*40];
int ans[N];
vector<int> G[N];
struct seg_tree{
void pushup(int k){data[k]=data[ls[k]]+data[rs[k]];}
int insert(int k,int l,int r,int x){
if(!k) k=++tot;
if(l==r) return data[k]++,k;
if(x<=mid) ls[k]=insert(ls[k],l,mid,x);
else rs[k]=insert(rs[k],mid+1,r,x);
pushup(k);
return k;
}
int ask(int k,int l,int r,int x,int y){
if(x<=l&&y>=r) return data[k];
int res=0;
if(x<=mid) res+=ask(ls[k],l,mid,x,y);
if(y>=mid+1) res+=ask(rs[k],mid+1,r,x,y);
return res;
}
int merge(int a,int b,int l,int r){
if(!a) return b;
if(!b) return a;
if(l==r) return data[a]+=data[b],a;
ls[a]=merge(ls[a],ls[b],l,mid);
rs[a]=merge(rs[a],rs[b],mid+1,r);
pushup(a);
return a;
}
}T;
void dfs(int x){
for(int k:G[x]){
if(k==fa[x]) continue;
dfs(k);
rt[x]=T.merge(rt[x],rt[k],1,m);
}
ans[x]+=T.ask(rt[x],1,m,a[x]+1,m);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),b[i]=a[i];
sort(b+1,b+1+n);
m=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b,rt[i]=T.insert(rt[i],1,m,a[i]);
for(int i=2;i<=n;i++){
int x;scanf("%d",&x);fa[i]=x;
G[x].push_back(i),G[i].push_back(x);
}
dfs(1);
for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
return 0;
}