【ZJOI2019】语言

【ZJOI2019】语言

Description

九条可怜是一个喜欢规律的女孩子。按照规律,第二题应该是一道和数据结构有关的题。

在一个遥远的国度,有 \(n\) 个城市。城市之间有 \(n − 1\) 条双向道路,这些道路保证了任何两个城市之间都能直接或者间接地到达。

在上古时代,这 \(n\) 个城市之间处于战争状态。在高度闭塞的环境中,每个城市都发展出了自己的语言。而在王国统一之后,语言不通给王国的发展带来了极大的阻碍。为了改善这种情况,国王下令设计了 \(m\) 种通用语,并进行了 \(m\) 次语言统一工作。在第 \(i\) 次统一工作中,一名大臣从城市 \(s_i\) 出发,沿着最短的路径走到了 \(t_i\),教会了沿途所有城市(包括 \(s_i, t_i\))使用第 \(i\) 个通用语。

一旦有了共通的语言,那么城市之间就可以开展贸易活动了。两个城市 \(u_i, v_i\) 之间可以开展贸易活动当且仅当存在一种通用语 \(L\) 满足 \(u_i\)\(v_i\) 最短路上的所有城市(包括 \(u_i, v_i\)),都会使用 \(L\)

为了衡量语言统一工作的效果,国王想让你计算有多少对城市 \((u, v)\ (u < v)\),他们之间可以开展贸易活动。

Input

第一行输入两个正整数 \(n, m\),表示城市数和通用语的数量。
接下来 \(n − 1\) 行,每行两个整数 \(x_i, y_i\ (1 \le x_i, y_i \le n)\),表示了一条连接城市 \(x_i, y_i\) 的道路。
接下来 \(m\) 行,每行两个整数 \(s_i, t_i\ (1 \le s_i, t_i \le n, s_i\neq t_i)\),表示一次语言普及工作。

Output

输出一行一个整数,表示可以开展贸易活动的城市对数量。

Sample Input

5 3
1 2
1 3
3 4
3 5
3 4
1 4
2 5

Sample Output

8

Data Constraint

\(1\le n,m\le 10^5\)

Solution

第一步可以观察出,就是求经过每个点的链并大小之和

可以发现,能到达的点一定构成一颗树

所以可以向虚树那样,将涉及到的所有点按dfn排序,然后相邻求LCA

可以使用树上差分+线段树合并解决

线段树每个节点维护最小/最大的dfn以及区间的答案就行了

Code

#include<bits/stdc++.h>
using namespace std;
#define F(i,a,b) for(int i=a;i<=b;i++)
#define Fd(i,a,b) for(int i=a;i>=b;i--)
#define N 100010
#define S 10000000

int n,m,fa[N][20],dep[N],dfn[N],sz[N],rk[N],cnt;
int tot,ls[S],rs[S],sum[S],le[S],ri[S],num[S];
vector<int>e[N];

void dfs(int u,int pre){
	sz[u]=1;
	dfn[u]=++cnt;rk[cnt]=u;
	dep[u]=dep[pre]+1;
	fa[u][0]=pre;
	F(i,0,18)fa[u][i+1]=fa[fa[u][i]][i];
	for(auto v:e[u]){
		if(v==pre)continue;
		dfs(v,u);
		sz[u]+=sz[v];
	}
}

int lca(int x,int y){
	if(!x||!y)return 0;
	if(dep[x]<dep[y])swap(x,y);
	Fd(i,19,0)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
	if(x==y)return x;
	Fd(i,19,0)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}

struct tree{
	int root;
	void ul(int x){if(!ls[x])ls[x]=++tot;}
	void ur(int x){if(!rs[x])rs[x]=++tot;}
	void update(int x){
		if(!ls[x]){
			le[x]=le[rs[x]];ri[x]=ri[rs[x]];sum[x]=sum[rs[x]];
		}else
		if(!rs[x]){
			le[x]=le[ls[x]];ri[x]=ri[ls[x]];sum[x]=sum[ls[x]];
		}else
		if(ls[x]&&rs[x]){
			le[x]=le[ls[x]]?le[ls[x]]:le[rs[x]];
			ri[x]=ri[rs[x]]?ri[rs[x]]:ri[ls[x]];
			sum[x]=sum[ls[x]]+sum[rs[x]]-dep[lca(rk[ri[ls[x]]],rk[le[rs[x]]])];
		}
	}
	int merge(int x,int y,int l,int r){
		if(!x||!y)return x|y;
		if(l==r){
			num[x]+=num[y];
			sum[x]=num[x]>0?dep[rk[l]]:0;
			le[x]=num[x]>0?l:0;
			ri[x]=num[x]>0?l:0;
			return x;
		}
		int mid=l+r>>1;
		ls[x]=merge(ls[x],ls[y],l,mid);
		rs[x]=merge(rs[x],rs[y],mid+1,r);
		update(x);
		return x;
	}
	void change(int x,int l,int r,int pos,int v){
		if(l==r){
			num[x]+=v;
			sum[x]=num[x]>0?dep[rk[l]]:0;
			le[x]=num[x]>0?l:0;
			ri[x]=num[x]>0?l:0;
			return;
		}
		int mid=l+r>>1;
		if(pos<=mid)ul(x),change(ls[x],l,mid,pos,v);
		else ur(x),change(rs[x],mid+1,r,pos,v);
		update(x);
	}
}t[N];

void add(int x,int y,int p){
	int z=lca(x,y),fz=fa[z][0];
	t[x].change(t[x].root,1,n,dfn[p],1);
	t[y].change(t[y].root,1,n,dfn[p],1);
	t[z].change(t[z].root,1,n,dfn[p],-1);
	if(fz)t[fz].change(t[fz].root,1,n,dfn[p],-1);
}

long long ans;

void calc(int u,int pre){
	for(auto v:e[u]){
		if(v==pre)continue;
		calc(v,u);
		t[u].merge(t[u].root,t[v].root,1,n);
	}
	ans+=sum[t[u].root]-dep[lca(rk[le[t[u].root]],rk[ri[t[u].root]])];
}

int main(){
	scanf("%d%d",&n,&m);
	F(i,1,n-1){
		int u,v;
		scanf("%d%d",&u,&v);
		e[u].push_back(v);e[v].push_back(u);
	}
	dfs(1,0);
	F(i,1,n)t[i].root=++tot;
	F(i,1,m){
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v,u);add(u,v,v);
	}
	calc(1,0);
	printf("%lld",ans/2);
	return 0;
}
posted @ 2023-02-20 21:50  冰雾  阅读(29)  评论(0编辑  收藏  举报