洛谷P5327 [ZJOI2019]语言

每个点的答案为所有经过该点的链的并的大小。得链并即为所有经过该点的链的端点构成的最小连通块,设端点按 \(dfs\) 序排序后为 \(a_i\),得最小连通块的边数为:

\[\large \sum_{i=1}^{cnt} dep_{a_i}-\sum_{i=1}^{cnt-1} dep_{\operatorname{lca}(a_i,a_{i+1})}-dep_{\operatorname{lca}(a_1,a_{cnt})} \]

即所有端点的深度减去排序后相邻点的 \(\operatorname{lca}\) 的深度。

用线段树维护 \(dfs\) 序,添加路径用树上差分,更新信息用线段树合并即可。

#include<bits/stdc++.h>
#define maxn 200010
#define maxm 8000010
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int n,m,cnt,tot;
int rt[maxn],f[maxn][19],dep[maxn],dfn[maxn],rev[maxn];
int ls[maxm],rs[maxm],val[maxm],mx[maxm],mn[maxm];
ll ans;
ll sum[maxm];
struct edge
{
    int to,nxt;
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to)
{
    e[++edge_cnt]={to,head[from]},head[from]=edge_cnt;
}
void dfs_pre(int x,int fa)
{
    dep[x]=dep[f[x][0]=fa]+1,rev[dfn[x]=++cnt]=x;
    for(int i=1;i<=17;++i) f[x][i]=f[f[x][i-1]][i-1];
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==fa) continue;
        dfs_pre(y,x);
    }
}
int lca(int x,int y)
{
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=17;i>=0;--i)
        if(f[x][i]&&dep[f[x][i]]>=dep[y])
            x=f[x][i];
    if(x==y) return x;
    for(int i=17;i>=0;--i)
        if(f[x][i]&&f[x][i]!=f[y][i])
            x=f[x][i],y=f[y][i];
    return f[x][0];
}
int get(int x,int y)
{
    if(!x||!y) return 0;
    return dep[lca(rev[x],rev[y])];
}
void pushup(int cur)
{
    mx[cur]=mx[rs[cur]]?mx[rs[cur]]:mx[ls[cur]];
    mn[cur]=mn[ls[cur]]?mn[ls[cur]]:mn[rs[cur]];
    sum[cur]=sum[ls[cur]]+sum[rs[cur]]-get(mx[ls[cur]],mn[rs[cur]]);
}
void modify(int l,int r,int pos,int v,int &cur)
{
    if(!cur) cur=++tot;
    if(l==r)
    {
        if((val[cur]+=v)>0) mx[cur]=mn[cur]=l,sum[cur]=dep[rev[l]];
        else mx[cur]=mn[cur]=sum[cur]=0;
        return;
    }
    if(pos<=mid) modify(l,mid,pos,v,ls[cur]);
    else modify(mid+1,r,pos,v,rs[cur]);
    pushup(cur);
}
int merge(int x,int y,int l,int r)
{
    if(!x||!y) return x+y;
    if(l==r)
    {
        if((val[x]+=val[y])>0) mx[x]=mn[x]=l,sum[x]=dep[rev[l]];
        else mx[x]=mn[x]=sum[x]=0;
        return x;
    }
    ls[x]=merge(ls[x],ls[y],l,mid);
    rs[x]=merge(rs[x],rs[y],mid+1,r);
    pushup(x);
    return x;
}
void dfs_ans(int x)
{
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==f[x][0]) continue;
        dfs_ans(y),rt[x]=merge(rt[x],rt[y],1,n);
    }
    ans+=sum[rt[x]]-get(mx[rt[x]],mn[rt[x]]);
}
void update(int x,int y,int id)
{
    int anc=lca(x,y);
    modify(1,n,id,1,rt[x]);
    modify(1,n,id,1,rt[y]);
    modify(1,n,id,-1,rt[anc]);
    if(f[anc][0]) modify(1,n,id,-1,rt[f[anc][0]]);
}
int main()
{
    read(n),read(m);
    for(int i=1;i<n;++i)
    {
        int x,y;
        read(x),read(y);
        add(x,y),add(y,x);
    }
    dfs_pre(1,0);
    for(int i=1;i<=m;++i)
    {
        int x,y;
        read(x),read(y);
        update(x,y,dfn[x]),update(x,y,dfn[y]);
    }
    dfs_ans(1),printf("%lld",ans/2);
    return 0;
}
posted @ 2020-11-10 22:09  lhm_liu  阅读(208)  评论(0编辑  收藏  举报