洛谷 P5327 [ZJOI2019]语言
洛谷 P5327 [ZJOI2019]语言
https://www.luogu.com.cn/problem/P5327
Tutorial
https://www.luogu.com.cn/blog/Sooke/solution-p5327
考虑如果 \(n,m \le 5 \times 10^3\) 怎么做.
对于一个点 \(u\) ,如果我们将所有经过它的 \(s,t\) 点拿出来,发现它所可以到达的区域实际就是这些点的虚树的大小.
虚树的大小可以在dfs序上用线段树维护,默认 \(1\) 节点在虚树中,每个区间维护区间内虚树大小,dfs序最小\(mn\)和最大的节点\(mx\),合并的时候减去左边的\(mx\)和右边的\(mn\)的lca深度,计算答案时减去根节点\(mn,mx\)的lca深度即可.
考虑\(n,m\le10^5\)的时候,我们可以用线段树合并来维护每个节点的虚树,将所有路径在树上差分一下即可.
Code
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
inline char gc() {
// return getchar();
static char buf[100000],*l=buf,*r=buf;
return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void rd(T &x) {
x=0; int f=1,ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=gc();}
x*=f;
}
typedef long long ll;
const int maxn=1e5+50;
const int maxdfc=maxn<<1;
int n,m; ll an;
int head[maxn];
int dfc,dfn[maxn],dep[maxn],a[20][maxdfc];;
vector<int> ad[maxn],de[maxn];
int root[maxn];
struct edge {
int to,nex;
edge(int to=0,int nex=0):to(to),nex(nex){}
};
vector<edge> G;
inline void addedge(int u,int v) {
G.push_back(edge(v,head[u])),head[u]=G.size()-1;
G.push_back(edge(u,head[v])),head[v]=G.size()-1;
}
namespace rmq {
int bit[20],lg2[maxdfc];
inline int cmp(int a,int b) {return dep[a]<dep[b]?a:b;}
void dfs(int u,int fa) {
a[0][dfn[u]=++dfc]=u;
for(int i=head[u];~i;i=G[i].nex) {
int v=G[i].to; if(v==fa) continue;
dep[v]=dep[u]+1;
dfs(v,u);
a[0][++dfc]=u;
}
}
void init() {
dfs(1,0);
bit[0]=1;
for(int i=1;i<20;++i) bit[i]=bit[i-1]<<1;
lg2[0]=-1;
for(int i=1;i<=dfc;++i) lg2[i]=lg2[i>>1]+1;
for(int k=1;bit[k]<=dfc;++k) {
for(int i=1;i+bit[k]-1<=dfc;++i) {
a[k][i]=cmp(a[k-1][i],a[k-1][i+bit[k-1]]);
}
}
}
inline int query(int l,int r) {
int k=lg2[r-l+1];
return cmp(a[k][l],a[k][r-bit[k]+1]);
}
inline int lca(int u,int v) {
if(dfn[u]>dfn[v]) swap(u,v);
return query(dfn[u],dfn[v]);
}
}
namespace seg {
const int maxnode=maxn*100;
int ncnt;
struct node {
int ls,rs,cnt,mn,mx,val;
node() {mn=mx=-1;}
void doit(node other) {
cnt=other.cnt,mn=other.mn,mx=other.mx,val=other.val;
}
} tree[maxnode];
inline void pushup(int u) {
int ls=tree[u].ls,rs=tree[u].rs;
if(tree[ls].mn==-1) {tree[u].doit(tree[rs]); return;}
if(tree[rs].mn==-1) {tree[u].doit(tree[ls]); return;}
tree[u].mn=tree[ls].mn,tree[u].mx=tree[rs].mx;
tree[u].val=tree[ls].val+tree[rs].val-dep[rmq::lca(tree[ls].mx,tree[rs].mn)];
}
void update(int &u,int l,int r,int qp,int qv) {
if(!u) u=++ncnt;
if(l==r) {
tree[u].cnt+=qv;
if(tree[u].cnt==0) tree[u]=node();
else {
tree[u].mn=tree[u].mx=a[0][qp];
tree[u].val=dep[a[0][qp]];
}
return;
}
int mid=(l+r)>>1;
if(qp<=mid) update(tree[u].ls,l,mid,qp,qv);
else update(tree[u].rs,mid+1,r,qp,qv);
pushup(u);
}
void merge(int &u,int v,int l,int r) {
if(u==0||v==0) {u=u+v; return;}
if(l==r) {
tree[u].cnt+=tree[v].cnt;
if(tree[u].cnt==0) tree[u]=node();
else {
tree[u].mn=tree[u].mx=a[0][l];
tree[u].val=dep[a[0][l]];
}
return;
}
int mid=(l+r)>>1;
merge(tree[u].ls,tree[v].ls,l,mid);
merge(tree[u].rs,tree[v].rs,mid+1,r);
pushup(u);
}
inline int sol(int u) {
if(tree[u].mn==-1) return 0;
return tree[u].val-dep[rmq::lca(tree[u].mn,tree[u].mx)]+1;
}
}
void dfs(int u,int fa) {
for(int i=0;i<ad[u].size();++i) {
seg::update(root[u],1,dfc,dfn[ad[u][i]],1);
}
for(int i=head[u];~i;i=G[i].nex) {
int v=G[i].to; if(v==fa) continue;
dfs(v,u);
seg::merge(root[u],root[v],1,dfc);
}
an+=seg::sol(root[u]);
for(int i=0;i<de[u].size();++i) {
seg::update(root[u],1,dfc,dfn[de[u][i]],-1);
}
}
int main() {
rd(n),rd(m);
memset(head,-1,sizeof(head));
for(int i=1;i<n;++i) {
int u,v; rd(u),rd(v);
addedge(u,v);
}
rmq::init();
for(int i=1;i<=n;++i) {
ad[i].push_back(i);
de[i].push_back(i);
}
for(int i=1;i<=m;++i) {
int s,t,w; rd(s),rd(t),w=rmq::lca(s,t);
ad[s].push_back(s),ad[s].push_back(t);
ad[t].push_back(s),ad[t].push_back(t);
de[w].push_back(s),de[w].push_back(s);
de[w].push_back(t),de[w].push_back(t);
}
dfs(1,0);
an=(an-n)/2;
printf("%d\n",an);
return 0;
}