线段树合并学习笔记

线段树合并

过程:

顾名思义,线段树合并是指建立一棵新的线段树,这棵线段树的每个节点都是两棵原线段树对应节点合并后的结果。它常常被用于维护树上或是图上的信息。

一般每个点建一棵线段树,以子树或者题目要求进行合并(比如连通块)。

实现:

我们考虑每次递归合并。把线段树 \(b\) 上的信息和线段树 \(a\) 上的信息合并更新。并且递归下去。如果某个线段树节点为空则直接返回。

例题:

P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并

题意

给一棵树,一共 \(m\) 次操作,操作点对 \((x,y,z)\) 表示在 \(x\to y\) 的路径上每个点对于 \(x\) 种类计数器 \(+1\)。最后请输出对于这棵树上每个点计数器值最大的种类。

解法

我们考虑一个暴力的想法,即我们对于 \(x\to1\) 的路径上 \(cnt_{z}+1\),对于 \(y\to 1\) 的路径上 \(cnt_z+1\)。对于 \(lca(x,y) \to 1\) 的路径上 \(cnt_z-1\),对于\(fa_{lca(x,y)}\to 1\) 的路径上 \(cnt_z-1\)。建 \(n\) 棵线段树,区间加,单点查。但是对于这道题无法通过。

我们考虑树上差分,即对于这样一个询问,我们仅对于 \(x\)\(cnt_{z}+1\),对于 \(y\)\(cnt_z+1\)。对于 \(lca(x,y)\)\(cnt_z-1\),对于$fa_{lca(x,y)} $ 的 \(cnt_z-1\)。这样答案可以通过子树和方式求得。

说干就干,但是子树和如果通过数组按位相加再比较实在是太慢了!所以我们考虑线段树合并,我们把答案子树 \(u\) 内答案都合并到 \(u\) 这棵线段树上来,然后直接查答案即可。

这题线段树需要动态开点,否则空间不够。

\(lca\) 使用树剖维护即可。

复杂度 \(O(n\log n+m\log n)\),足够通过本题。

代码
#include<bits/stdc++.h>
using namespace std;

const int N=100000;
#define mid ((l+r)>>1)

int n,rt[100005];
int sum[5000005],cnt=0,res[5000005],ls[5000005],rs[5000005];
int m,ans[100005];
vector<int> G[100005];
//sum 出现次数,res种类--也就是答案

struct tree{
	
	int top[5000005];
	int fa[5000005],bgs[5000005],dep[5000005],siz[5000005];
	
	void dfs(int x,int fat){
	  dep[x]=dep[fat]+1;
	  fa[x]=fat;siz[x]=1;
	  for(int k:G[x]){
	    if(k!=fat){
	      dfs(k,x);
	      siz[x]+=siz[k];
	      if(siz[k]>siz[bgs[x]]) bgs[x]=k;
	    }
	  }
	}
	
	void DFS(int x,int fat,int tp){
	  top[x]=tp;
	  if(bgs[x]){
	      DFS(bgs[x],x,tp);
	  }
	  for(int k:G[x]){
	      if(k!=fat&&k!=bgs[x]) DFS(k,x,k);
	  }
	}
	
	int lca(int x,int y){
	  while(top[x]^top[y]) dep[top[x]]>dep[top[y]]?x=fa[top[x]]:y=fa[top[y]];
	  return dep[x]<dep[y]?x:y;
	}
	
}_sp;

struct seg_tree{
	
	int merge(int a,int b,int l,int r){
		if(!a) return b;
		if(!b) return a;
		if(l==r) return sum[a]+=sum[b],a;
		ls[a]=merge(ls[a],ls[b],l,mid),rs[a]=merge(rs[a],rs[b],mid+1,r);
		pushup(a,ls[a],rs[a]);
		return a;
	}
	
	void pushup(int k,int l,int r){
		if(sum[ls[k]]<sum[rs[k]]) res[k]=res[rs[k]],sum[k]=sum[rs[k]];
		else res[k]=res[ls[k]],sum[k]=sum[ls[k]];
	}
	
	int modify(int k,int l,int r,int co,int val){
		if(!k) k=++cnt;
		if(l==r) return sum[k]+=val,res[k]=co,k;
		if(co<=mid) ls[k]=modify(ls[k],l,mid,co,val);
		else rs[k]=modify(rs[k],mid+1,r,co,val);
		pushup(k,l,r);
		return k;
	}
	
	void get_ans(int x){
		for(int k:G[x]){
			if(k==_sp.fa[x]) continue;
			get_ans(k);
			rt[x]=merge(rt[x],rt[k],1,100000); 
		}
		ans[x]=res[rt[x]];
		if(sum[rt[x]]==0) ans[x]=0;
	}
	
}T;

void read(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<n;i++){
		int u,v;scanf("%d%d",&u,&v);
		G[u].push_back(v),G[v].push_back(u);
	}
	_sp.dfs(1,0),_sp.DFS(1,0,1);
}

void solve(){
	for(int i=1;i<=m;i++){
		int x,y,z;
		scanf("%d%d%d",&x,&y,&z);
		rt[x]=T.modify(rt[x],1,N,z,1),rt[y]=T.modify(rt[y],1,N,z,1);
		int _lca=_sp.lca(x,y);
//		printf("lca:%d\n",_lca);
		rt[_lca]=T.modify(rt[_lca],1,N,z,-1);
		rt[_sp.fa[_lca]]=T.modify(rt[_sp.fa[_lca]],1,N,z,-1);
	}
//	return;
	T.get_ans(1);
	
	for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
}

int main(){
//	freopen("data.in","r",stdin);
	read();
//	return 0;
	solve();
	return 0;
}

P3224 [HNOI2012] 永无乡

题意

给定一个图,每个点都有一个重要度。两种操作,操作一连接两个点,操作二求对于某个点连通的所有点中排名第 \(y\) 小的是哪个。

解法

考虑用并查集维护连通性。对于每个连通块用线段树维护重要度。

  • 对于每次联通操作,先并查集合并,再线段树合并。

  • 对于每次查询,我们在所对应的连通块所对应的线段树上二分查找即可。

复杂度 \(O(n\log n)\),足以通过本题。线段树要动态开点。注意合并时最后要 return a

代码
#include<bits/stdc++.h>
using namespace std;
#define mid ((l+r)>>1)

const int N=1e5+7;
const int M=3e6+7;

int fa[N];
int n,m,q;
int rt[N],cnt,ls[M],rs[M],rnk[M],sum[M];
int x,y;

struct node{
	int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
}bcj;

struct seg_tree{
	
	void pushup(int k,int l,int r){
		sum[k]=sum[ls[k]]+sum[rs[k]];
	}
	
	int modify(int k,int l,int r,int pos,int idx){
		if(!k) k=++cnt; 
		if(l==r) return rnk[k]=idx,sum[k]++,k;
		if(pos<=mid) ls[k]=modify(ls[k],l,mid,pos,idx);
		else rs[k]=modify(rs[k],mid+1,r,pos,idx);
		pushup(k,l,r);
		return k;
	}
	
	int query(int k,int l,int r,int val){
		if(sum[k]<val||!k) return 0;
		if(l==r) return rnk[k];
		if(val<=sum[ls[k]]) return query(ls[k],l,mid,val);
		else return query(rs[k],mid+1,r,val-sum[ls[k]]);
	}
	
	int merge(int a,int b,int l,int r){
		if(!a) return b;
		if(!b) return a;
		if(l==r){
			if(rnk[b]) rnk[a]=rnk[b],sum[a]+=sum[b];
			return a;
		}
		ls[a]=merge(ls[a],ls[b],l,mid);
		rs[a]=merge(rs[a],rs[b],mid+1,r);
		pushup(a,l,r);
		return a;
	}
	
}T;

int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		fa[i]=i;int x;scanf("%d",&x);
		rt[i]=T.modify(rt[i],1,n,x,i);
	}
	for(int i=1;i<=m;i++){
		int x,y;scanf("%d%d",&x,&y);
		x=bcj.find(x),y=bcj.find(y);
		fa[y]=x;
		rt[x]=T.merge(rt[x],rt[y],1,n);
	}
	char s[5];
	scanf("%d",&q);
	while(q--){
		scanf("%s",s);scanf("%d%d",&x,&y);
		if(s[0]=='B'){
			x=bcj.find(x),y=bcj.find(y);
			if(x==y) continue;
			fa[y]=x;
			rt[x]=T.merge(rt[x],rt[y],1,n);
		}
		else{
			x=bcj.find(x);
			int ans=T.query(rt[x],1,n,y);
			if(!ans) printf("-1\n");
			else printf("%d\n",ans);
		}
	}
	return 0;
}

P3899 [湖南集训] 更为厉害

题意

求对于一个点编号为 \(p\) 的点 \(a\),求树上有多少三元组 \((a,b,c)\) 满足 \(a,b\)\(c\) 的祖先,并且 \(dis(a,b)\le k\)。(\(k\) 为给定的常数)。

分析

我们考虑分类讨论。由题意可得,\(a,b,c\) 一定在一条链上。

  • \(b\)\(a\) 上方,则最多可以跳到 \(dep_a-k\) 处(不跳出去的情况下)。因而此时答案为 \((size_a-1)\times \min(dep_a-1,k)\)
  • \(b\)\(a\) 下方,那么 \(b\) 可以在 \(a\) 子树深度为 \([deep_a+1,deep_a+k]\) 范围内,此时 \(c\) 的数量就是 \(size_b-1\),因而我们考虑建立权值线段树,求下标为 \([deep_a+1,deep_a+k]\) 区间内 \(size-1\) 的和即可。对于每个点维护每个深度的答案即可,考虑子树内所有点所在线段树需要向该子树根节点做线段树合并。
代码
#include<bits/stdc++.h>
using namespace std;

#define int long long
#define mid ((l+r)>>1)

const int N=3e5+7;

int n,q,rt[N],cnt,dep[N],sz[N];
int ans[N];
vector<int> G[N];
vector<pair<int,int> > Q[N];
int ls[N*10<<2],rs[N*10<<2],sum[N*10<<2];

struct seg_tree{
	
	void pushup(int k,int l,int r){
		sum[k]=sum[ls[k]]+sum[rs[k]];
	}
	
	int modify(int k,int l,int r,int x,int val){
		if(!k) k=++cnt;
		if(l==r) return (sum[k]+=val),k;
		if(x<=mid) ls[k]=modify(ls[k],l,mid,x,val);
		else rs[k]=modify(rs[k],mid+1,r,x,val);
		pushup(k,l,r);
		return k;
	}

	int merge(int a,int b,int l,int r){
		if(!a) return b;if(!b) return a;
		if(l==r) return (sum[a]+=sum[b]),a;
		ls[a]=merge(ls[a],ls[b],l,mid);
		rs[a]=merge(rs[a],rs[b],mid+1,r);
		pushup(a,l,r);
		return a;
	}
	
	int query(int k,int l,int r,int x,int y){
		if(!k) return 0;
		if(x<=l&&y>=r) return sum[k];
		int res=0;
		if(x<=mid) res+=query(ls[k],l,mid,x,y);
		if(y>=mid+1) res+=query(rs[k],mid+1,r,x,y);
		return res;
	}
	
}T;

void DFS(int x,int fa){
	dep[x]=dep[fa]+1,sz[x]=1;
	for(int k:G[x]){
		if(k==fa) continue;
		DFS(k,x);
		rt[x]=T.merge(rt[x],rt[k],1,n);
		sz[x]+=sz[k];
	}
	rt[x]=T.modify(rt[x],1,n,dep[x],sz[x]-1);
	for(auto it:Q[x]){
		int id=it.first,k=it.second;
		ans[id]=T.query(rt[x],1,n,min(n,dep[x]+1),min(dep[x]+k,n))+(sz[x]-1)*min(dep[x]-1,k);
	}
}

signed main(){
	scanf("%lld%lld",&n,&q);
	for(int i=1;i<n;i++){
		int u,v;scanf("%lld%lld",&u,&v);
		G[u].push_back(v),G[v].push_back(u);
	}
	for(int i=1;i<=q;i++){
		int x,k;scanf("%lld%lld",&x,&k);
		Q[x].push_back(make_pair(i,k));
	}
	DFS(1,0);
	for(int i=1;i<=q;i++) printf("%lld\n",ans[i]);
	return 0;
}

P3605 [USACO17JAN] Promotion Counting P

分析

板子。

代码
#include<bits/stdc++.h>
using namespace std;

#define mid ((l+r)>>1)

const int N=1e5+7;

int n,m;
int a[N],b[N];

int fa[N];

int data[N*40];
int ls[N*40],rs[N*40];
int tot;
int rt[N*40];

int ans[N];

vector<int> G[N];

struct seg_tree{
  void pushup(int k){data[k]=data[ls[k]]+data[rs[k]];}
  int insert(int k,int l,int r,int x){
    if(!k) k=++tot;
    if(l==r) return data[k]++,k;
    if(x<=mid) ls[k]=insert(ls[k],l,mid,x);
    else rs[k]=insert(rs[k],mid+1,r,x);
    pushup(k);
    return k;
  }
  int ask(int k,int l,int r,int x,int y){
    if(x<=l&&y>=r) return data[k];
    int res=0;
    if(x<=mid) res+=ask(ls[k],l,mid,x,y);
    if(y>=mid+1) res+=ask(rs[k],mid+1,r,x,y);
    return res;
  }
  int merge(int a,int b,int l,int r){
    if(!a) return b;
    if(!b) return a;
    if(l==r) return data[a]+=data[b],a;
    ls[a]=merge(ls[a],ls[b],l,mid);
    rs[a]=merge(rs[a],rs[b],mid+1,r);
    pushup(a);
    return a;
  }
}T;

void dfs(int x){
  for(int k:G[x]){
    if(k==fa[x]) continue;
    dfs(k);
    rt[x]=T.merge(rt[x],rt[k],1,m);
  }
  ans[x]+=T.ask(rt[x],1,m,a[x]+1,m);
}

int main(){
  scanf("%d",&n);
  for(int i=1;i<=n;i++) scanf("%d",&a[i]),b[i]=a[i];
  sort(b+1,b+1+n);
  m=unique(b+1,b+1+n)-b-1;
  for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b,rt[i]=T.insert(rt[i],1,m,a[i]);
  for(int i=2;i<=n;i++){
    int x;scanf("%d",&x);fa[i]=x;
    G[x].push_back(i),G[i].push_back(x);
  }
  dfs(1);
  for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
  return 0;
}
posted @ 2023-07-25 13:44  Zimo_666  阅读(74)  评论(0编辑  收藏  举报