P5327 [ZJOI2019]语言

题目

P5327 [ZJOI2019]语言

分析

线段树合并+树上差分。

首先我们发现答案其实就是:对于每一个点来说的连通块大小之和。

那么现在问题在于怎么来维护这个连通块的大小。

我们可以考虑对每一个点开一个线段树,保存:\(dfn\) 序列对应的点被路径覆盖次数和长度。

然后对于这样一类树上路径修改且每个点都要查询的问题,我们可以考虑使用线段树合并+树上差分来解决。

那么这道题就显而易见了,是直接打上两个 \(+1\) 标记,然后在 \(fa[lca]\) 处打上 \(-2\) 的标记。

接下来就是直接线段树合并,重点在于怎么具体维护有多少个点,我们发现如果当前区间的都是大于 \(0\) 的话,那么个数就是 \(r-l+1\) ,因为我们每一次修改的时候,一定是一个连续的链(从下到上)。

这是由我们树剖来决定的,也就是有 \(logn\) 个区间要进行修改的意思。

具体见代码。

代码

#include<bits/stdc++.h>
using namespace std;
template <typename T>
inline void read(T &x){
	x=0;char ch=getchar();bool f=false;
	while(!isdigit(ch)){if(ch=='-'){f=true;}ch=getchar();}
	while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	x=f?-x:x;
	return ;
}
template <typename T>
inline void write(T x){
	if(x<0) putchar('-'),x=-x;
	if(x>9) write(x/10);
	putchar(x%10^48);
	return ;
}
const int N=1e5+5;
#define ll long long
int n,m;
ll Ans;
int head[N],nex[N<<1],to[N<<1],idx;
inline void add(int u,int v){
	nex[++idx]=head[u];
	to[idx]=v;
	head[u]=idx;
	return ;
}
int fa[N],dep[N],siz[N],son[N],top[N],dfn[N],rev[N],DFN;
void dfs1(int x,int f){
	fa[x]=f,dep[x]=dep[f]+1,siz[x]=1;
	for(int i=head[x];i;i=nex[i]){
		int y=to[i];
		if(y==f) continue;
		dfs1(y,x);siz[x]+=siz[y];
		if(siz[y]>siz[son[x]]) son[x]=y;
	}
	return ;
}
void dfs2(int x){
	if(x==son[fa[x]]) top[x]=top[fa[x]];
	else top[x]=x;
	dfn[x]=++DFN,rev[DFN]=x;
	if(son[x]) dfs2(son[x]);
	for(int i=head[x];i;i=nex[i]){
		int y=to[i];
		if(y==fa[x]||y==son[x]) continue;
		dfs2(y);
	}
	return ;
}
inline int QueryLca(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		x=fa[top[x]];
	}
	return dep[x]<dep[y]?x:y;
}
int Root[N];
struct SGT{
	int sum,num,ls,rs;
	#define sum(x) t[x].sum
	#define num(x) t[x].num
	#define ls(x) t[x].ls
	#define rs(x) t[x].rs
}t[N*250];
int cur;
void Modify(int &x,int l,int r,int ql,int qr,int v){
	if(!x) x=++cur;
	if(ql<=l&&qr>=r) return sum(x)+=v,num(x)=(sum(x)>0?(r-l+1):(num(ls(x))+num(rs(x)))),void();
	int mid=l+r>>1;
	if(ql<=mid) Modify(ls(x),l,mid,ql,qr,v);
	if(qr>mid) Modify(rs(x),mid+1,r,ql,qr,v);
	num(x)=(sum(x)>0?(r-l+1):(num(ls(x))+num(rs(x))));
	return ;
}
int Query(int x,int l,int r,int ql,int qr){
	if(!x) return 0;
	if(ql<=l&&r<=qr) return num(x);
	int mid=l+r>>1,res=0;
	if(ql<=mid) res+=Query(ls(x),l,mid,ql,qr);
	if(qr>mid) res+=Query(rs(x),mid+1,r,ql,qr);
	return res;
}
int Merge(int x,int y,int l,int r){
	if(!x||!y) return x|y;
	sum(x)+=sum(y);int mid=l+r>>1;
	ls(x)=Merge(ls(x),ls(y),l,mid),rs(x)=Merge(rs(x),rs(y),mid+1,r);
	num(x)=(sum(x)>0?(r-l+1):(num(ls(x))+num(rs(x))));
	return x;
}
typedef pair<int,int> PII;
PII path[N];
int Cnt;
void GetSeq(int x,int y){
	Cnt=0;
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		path[++Cnt]=make_pair(dfn[top[x]],dfn[x]);
		x=fa[top[x]];
	}
	if(dep[x]<dep[y]) swap(x,y);
	path[++Cnt]=make_pair(dfn[y],dfn[x]);
	return ;
}
void Solve(int x){
	Modify(Root[x],1,n,dfn[x],dfn[x],1);
	for(int i=head[x];i;i=nex[i]){
		int y=to[i];
		if(y==fa[x]) continue;
		Solve(y);Root[x]=Merge(Root[x],Root[y],1,n);
	}
	Ans+=Query(Root[x],1,n,1,n)-1;
	Modify(Root[x],1,n,dfn[x],dfn[x],-1);
	return ;
}
int main(){
	read(n),read(m);
	for(int i=1;i<n;i++){
		int u,v;read(u),read(v);
		add(u,v),add(v,u);
	}
	dfs1(1,0);dfs2(1);
	for(int i=1;i<=m;i++){
		int s,t;read(s),read(t);
		int lca=QueryLca(s,t),f=fa[lca];
		GetSeq(s,t);
		for(int j=1;j<=Cnt;j++){
			Modify(Root[s],1,n,path[j].first,path[j].second,1);
			Modify(Root[t],1,n,path[j].first,path[j].second,1);
			Modify(Root[f],1,n,path[j].first,path[j].second,-2);
		}
	}
	Solve(1);
	write(Ans/2);
	return 0;
}
posted @ 2021-05-17 21:02  __Anchor  阅读(37)  评论(0编辑  收藏  举报