[ZJOI2019] 语言 题解
不愧是 \(ZJOI\),《最可做的一道题》都让人一头雾水……
首先将问题转化到链上。
可以将总共的组数转化为每个点可以到达的城市。
明显给每个点建一棵动态开点线段树,维护可以和他通商的点。很明显,可以通商的点的标号连续的一段。我们可以将可以将每一次传播语言的工作当作区间修改,很明显可以用差分。最后再用线段树合并从后往前计算出每一个点的答案。
那假如问题转化到树上呢?
众所周知,假如我们想要让一棵树变成多个 木棍 链,树链剖分就是我们要熟悉掌握的一个知识点。
用树链剖分的方法,就可以最多进行 \(\log_2n\) 次操作,将所有本次被传播语言的点加入线段树。
时间复杂度 \(O(n\log^2n)\)。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+5;
const int M=2e7+5;
int n,m,k,id,d;ll ans;
int h[N],to[N*2],nxt[N*2];
int ls[M],rs[M],rt[N],v[M],lb[M];
int dep[N],fa[N],sn[N];
int dfn[N],sz[N],tp[N];
struct line{int x,y;};
struct que{int x,y,z;};
vector<line>a;
vector<que>q[N];
void ad(int x,int y){
to[++k]=y;
nxt[k]=h[x];
h[x]=k;
}void push_up(int x,int l,int r){
if(lb[x]) v[x]=r-l+1;
else v[x]=v[ls[x]]+v[rs[x]];
}void add(int &p,int l,int r,int x,int y,int kk){
if(!p) p=++d;
if(x<=l&&r<=y) lb[p]+=kk;
else{
int mid=(l+r)/2;
if(x<=mid) add(ls[p],l,mid,x,y,kk);
if(y>mid) add(rs[p],mid+1,r,x,y,kk);
}push_up(p,l,r);
}int merge(int x,int y,int l,int r){
if(!x||!y) return x|y;
lb[x]+=lb[y];
if(l<r){
int mid=(l+r)/2;
ls[x]=merge(ls[x],ls[y],l,mid);
rs[x]=merge(rs[x],rs[y],mid+1,r);
}push_up(x,l,r);
return x;
}void dfs1(int x,int f){
int mx=0;
dep[x]=dep[f]+1;
sz[x]=1;fa[x]=f;
for(int i=h[x];i;i=nxt[i]){
int y=to[i];
if(y==f) continue;
dfs1(y,x);
if(sz[y]>mx){
sn[x]=y;
mx=sz[y];
}sz[x]+=sz[y];
}
}void dfs2(int x,int f){
dfn[x]=++id;tp[x]=f;
q[x].push_back({id,id,1});
if(fa[x]) q[fa[x]].push_back({id,id,-1});
if(!sn[x]) return;
dfs2(sn[x],f);
for(int i=h[x];i;i=nxt[i])
if(to[i]!=fa[x]&&to[i]!=sn[x])
dfs2(to[i],to[i]);
}int lca(int x,int y){
a.clear();
while(tp[x]!=tp[y]){
if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
a.push_back({dfn[tp[x]],dfn[x]});
x=fa[tp[x]];
}if(dep[x]>dep[y]) swap(x,y);
a.push_back({dfn[x],dfn[y]});
return x;
}void solve(int x){
for(int i=h[x];i;i=nxt[i]){
int y=to[i];
if(y==fa[x]) continue;
solve(y);
rt[x]=merge(rt[x],rt[y],1,n);
}for(int i=0;i<q[x].size();i++)
add(rt[x],1,n,q[x][i].x,q[x][i].y,q[x][i].z);
ans+=v[rt[x]]-1;
}int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
cin>>n>>m;
for(int i=1,x,y;i<n;i++){
cin>>x>>y;
ad(x,y);ad(y,x);
}dfs1(1,0);dfs2(1,1);
while(m--){
int x,y,lc;
cin>>x>>y;
lc=fa[lca(x,y)];
for(int i=0;i<a.size();i++){
q[x].push_back({a[i].x,a[i].y,1});
q[y].push_back({a[i].x,a[i].y,1});
if(lc) q[lc].push_back({a[i].x,a[i].y,-2});
}
}solve(1);
cout<<ans/2;
return 0;
}