【ZJOI2019】语言
【ZJOI2019】语言
Description
九条可怜是一个喜欢规律的女孩子。按照规律,第二题应该是一道和数据结构有关的题。
在一个遥远的国度,有 \(n\) 个城市。城市之间有 \(n − 1\) 条双向道路,这些道路保证了任何两个城市之间都能直接或者间接地到达。
在上古时代,这 \(n\) 个城市之间处于战争状态。在高度闭塞的环境中,每个城市都发展出了自己的语言。而在王国统一之后,语言不通给王国的发展带来了极大的阻碍。为了改善这种情况,国王下令设计了 \(m\) 种通用语,并进行了 \(m\) 次语言统一工作。在第 \(i\) 次统一工作中,一名大臣从城市 \(s_i\) 出发,沿着最短的路径走到了 \(t_i\),教会了沿途所有城市(包括 \(s_i, t_i\))使用第 \(i\) 个通用语。
一旦有了共通的语言,那么城市之间就可以开展贸易活动了。两个城市 \(u_i, v_i\) 之间可以开展贸易活动当且仅当存在一种通用语 \(L\) 满足 \(u_i\) 到 \(v_i\) 最短路上的所有城市(包括 \(u_i, v_i\)),都会使用 \(L\)。
为了衡量语言统一工作的效果,国王想让你计算有多少对城市 \((u, v)\ (u < v)\),他们之间可以开展贸易活动。
Input
第一行输入两个正整数 \(n, m\),表示城市数和通用语的数量。
接下来 \(n − 1\) 行,每行两个整数 \(x_i, y_i\ (1 \le x_i, y_i \le n)\),表示了一条连接城市 \(x_i, y_i\) 的道路。
接下来 \(m\) 行,每行两个整数 \(s_i, t_i\ (1 \le s_i, t_i \le n, s_i\neq t_i)\),表示一次语言普及工作。
Output
输出一行一个整数,表示可以开展贸易活动的城市对数量。
Sample Input
5 3
1 2
1 3
3 4
3 5
3 4
1 4
2 5
Sample Output
8
Data Constraint
\(1\le n,m\le 10^5\)
Solution
第一步可以观察出,就是求经过每个点的链并大小之和
可以发现,能到达的点一定构成一颗树
所以可以向虚树那样,将涉及到的所有点按dfn排序,然后相邻求LCA
可以使用树上差分+线段树合并解决
线段树每个节点维护最小/最大的dfn以及区间的答案就行了
Code
#include<bits/stdc++.h>
using namespace std;
#define F(i,a,b) for(int i=a;i<=b;i++)
#define Fd(i,a,b) for(int i=a;i>=b;i--)
#define N 100010
#define S 10000000
int n,m,fa[N][20],dep[N],dfn[N],sz[N],rk[N],cnt;
int tot,ls[S],rs[S],sum[S],le[S],ri[S],num[S];
vector<int>e[N];
void dfs(int u,int pre){
sz[u]=1;
dfn[u]=++cnt;rk[cnt]=u;
dep[u]=dep[pre]+1;
fa[u][0]=pre;
F(i,0,18)fa[u][i+1]=fa[fa[u][i]][i];
for(auto v:e[u]){
if(v==pre)continue;
dfs(v,u);
sz[u]+=sz[v];
}
}
int lca(int x,int y){
if(!x||!y)return 0;
if(dep[x]<dep[y])swap(x,y);
Fd(i,19,0)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
if(x==y)return x;
Fd(i,19,0)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
struct tree{
int root;
void ul(int x){if(!ls[x])ls[x]=++tot;}
void ur(int x){if(!rs[x])rs[x]=++tot;}
void update(int x){
if(!ls[x]){
le[x]=le[rs[x]];ri[x]=ri[rs[x]];sum[x]=sum[rs[x]];
}else
if(!rs[x]){
le[x]=le[ls[x]];ri[x]=ri[ls[x]];sum[x]=sum[ls[x]];
}else
if(ls[x]&&rs[x]){
le[x]=le[ls[x]]?le[ls[x]]:le[rs[x]];
ri[x]=ri[rs[x]]?ri[rs[x]]:ri[ls[x]];
sum[x]=sum[ls[x]]+sum[rs[x]]-dep[lca(rk[ri[ls[x]]],rk[le[rs[x]]])];
}
}
int merge(int x,int y,int l,int r){
if(!x||!y)return x|y;
if(l==r){
num[x]+=num[y];
sum[x]=num[x]>0?dep[rk[l]]:0;
le[x]=num[x]>0?l:0;
ri[x]=num[x]>0?l:0;
return x;
}
int mid=l+r>>1;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
update(x);
return x;
}
void change(int x,int l,int r,int pos,int v){
if(l==r){
num[x]+=v;
sum[x]=num[x]>0?dep[rk[l]]:0;
le[x]=num[x]>0?l:0;
ri[x]=num[x]>0?l:0;
return;
}
int mid=l+r>>1;
if(pos<=mid)ul(x),change(ls[x],l,mid,pos,v);
else ur(x),change(rs[x],mid+1,r,pos,v);
update(x);
}
}t[N];
void add(int x,int y,int p){
int z=lca(x,y),fz=fa[z][0];
t[x].change(t[x].root,1,n,dfn[p],1);
t[y].change(t[y].root,1,n,dfn[p],1);
t[z].change(t[z].root,1,n,dfn[p],-1);
if(fz)t[fz].change(t[fz].root,1,n,dfn[p],-1);
}
long long ans;
void calc(int u,int pre){
for(auto v:e[u]){
if(v==pre)continue;
calc(v,u);
t[u].merge(t[u].root,t[v].root,1,n);
}
ans+=sum[t[u].root]-dep[lca(rk[le[t[u].root]],rk[ri[t[u].root]])];
}
int main(){
scanf("%d%d",&n,&m);
F(i,1,n-1){
int u,v;
scanf("%d%d",&u,&v);
e[u].push_back(v);e[v].push_back(u);
}
dfs(1,0);
F(i,1,n)t[i].root=++tot;
F(i,1,m){
int u,v;
scanf("%d%d",&u,&v);
add(u,v,u);add(u,v,v);
}
calc(1,0);
printf("%lld",ans/2);
return 0;
}