[ZJOI2019]语言
Link
Description
\(n\) 个点的树,有 \(m\) 条路径。询问有多少个点对(u>v) \((u,v)\),存在至少一条路径覆盖这两个点。
Solution
我们不可能枚举点对,而且也不能枚举路径,因为路径有交。所以考虑枚举其中一个点 ,然后统计有多少个 \((u,v)\)。
那么就有一个暴力的做法了,枚举点 \(u\),枚举路径,如果路径经过 \(u\),就把该路径上的点打上标记,最后查询有多少个点有标记。可以用树剖实现,复杂度 \(O(n^2 \log^2 n)\) 。
进一步观察,我们发现有很多修改是重复且没必要的。具体地说,一条链影响的是链上所有的点。而这些点显然是连续的。也就是说,父亲和儿子的操作有很大的交集。有什么办法可以使得尽可能多地减少操作次数呢。我们想到了父亲可以直接继承儿子的线段树。所以就有了线段树合并。对于加链的话,发现只需要在遍历链的端点处的时候修改,父亲直接继承即可,这样就省掉了重复修改的时间。但直接继承的话,可能儿子包含某条链,但父亲不包含,所以需要在遍历父亲的时候将其删除。那么这个“父亲”到底是哪个节点?容易发现,其实就是链两端 LCA 的父亲。于是我们又发明了差分。
时间复杂度 \(O(n\log^2 n)\)。注意到,树剖的区间修改每次实际上最多会新建 \(4\log n\) 个节点,所以空间需要开大点。
#include<stdio.h>
#include<vector>
#include<cassert>
using namespace std;
inline int read(){
int x=0,flag=1; char c=getchar();
while(c<'0'||c>'9'){if(c=='-') flag=0;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-48;c=getchar();}
return flag? x:-x;
}
const int N=1e5+7;
struct Node{
int s,cnt,ls,rs,lf,rf;
}t[N*120];
struct Path{int u,v,cnt;};
vector<Path> P[N];
struct E{
int next,to;
}e[N<<1];
int n,m,head[N],cnt=0,rt[N];
int fa[N][20],sz[N],in[N],dep[N],son[N],tp[N];
inline void add(int id,int to){
e[++cnt]=(E){head[id],to};
head[id]=cnt;
}
void dfs(int u){
sz[u]=1;
for(int i=1;i<=17;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==fa[u][0]) continue;
dep[v]=dep[u]+1;
fa[v][0]=u,dfs(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
void Dfs(int u,int top){
static int tot=0;
in[u]=++tot,tp[u]=top;
if(!son[u]) return;
Dfs(son[u],top);
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==fa[u][0]||v==son[u]) continue;
Dfs(v,v);
}
}
inline void swap(int &x,int &y){x^=y,y^=x,x^=y;}
inline int Lca(int u,int v){
if(u==v) return u;
if(dep[u]<dep[v]) swap(u,v);
for(int i=17;~i;i--)
if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
if(u==v) return u;
for(int i=17;~i;i--)
if(fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
inline void update(int id){
t[id].s=(t[id].cnt? (t[id].rf-t[id].lf+1):(t[t[id].ls].s+t[t[id].rs].s));
}
inline void push(int id,int v){t[id].cnt+=v,update(id);}
int merge(int x,int y,int lf=1,int rf=n){
if(!x||!y) return x+y;
t[x].cnt+=t[y].cnt;
if(lf!=rf){
int mid=(lf+rf)>>1;
t[x].ls=merge(t[x].ls,t[y].ls,lf,mid);
t[x].rs=merge(t[x].rs,t[y].rs,mid+1,rf);
}
return update(x),x;
}
int L,R,C;
void modify(int &id,int lf=1,int rf=n){
static int tot=0;
// assert(tot<=8000000);
if(!id) id=++tot,t[id].lf=lf,t[id].rf=rf;
if(L<=lf&&rf<=R) push(id,C);
else{
int mid=(lf+rf)>>1;
if(L<=mid) modify(t[id].ls,lf,mid);
if(R>mid) modify(t[id].rs,mid+1,rf);
update(id);
}
}
inline void Modify(int &Rt,int u,int v){
while(tp[u]!=tp[v]){
if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
L=in[tp[u]],R=in[u],modify(Rt);
u=fa[tp[u]][0];
}
if(dep[u]<dep[v]) swap(u,v);
L=in[v],R=in[u],modify(Rt);
}
long long ans=0;
void DFS(int u){
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==fa[u][0]) continue;
DFS(v);
rt[u]=merge(rt[u],rt[v]);
}
for(unsigned int i=0;i<P[u].size();i++)
C=P[u][i].cnt,Modify(rt[u],P[u][i].u,P[u][i].v);
ans+=t[rt[u]].s>0? 1ll*t[rt[u]].s-1:0;
}
int main(){
n=read(),m=read();
for(int i=1;i<n;i++){
int u=read(),v=read();
add(u,v),add(v,u);
}
dep[1]=1,dfs(1),Dfs(1,1);
for(int i=1;i<=m;i++){
int u=read(),v=read();
if(u==v) continue;
int lca=Lca(u,v);
P[u].push_back((Path){u,v,1});
P[v].push_back((Path){u,v,1});
P[fa[lca][0]].push_back((Path){u,v,-2});
}
DFS(1);
printf("%lld",ans>>1);
}