luogu P5327 [ZJOI2019]语言

传送门

显然的想法是对每个点求出能通过某种语言到的点个数,然后加起来\(/2\)就是答案.每次加入一条路径,就可以更新路径上所有点到达其他点的状态.那个我们用线段树维护,每次对路径上所有点的线段树上该路径对应的dfn区间覆盖(用树剖处理),最后统计每个线段树上有值的位置个数

注意每次是对一条路径上的线段树操作,路径修改可以联想到树上差分,即两端点做正权修改,lca和lca的父亲做负权修改.本题类似,我们在两端点处对对应区间执行+1操作,在lca和lca的父亲处对对应区间执行-1,然后套个线段树合并,就可以在每个节点处统计答案

// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<cmath>
#include<ctime>
#include<queue>
#include<map>
#include<set>
#define LL long long
#define db double

using namespace std;
const int N=1e5+10;
int rd()
{
    int x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
    return x*w;
}
int to[N<<1],nt[N<<1],hd[N],tot=1;
void add(int x,int y)
{
    ++tot,to[tot]=y,nt[tot]=hd[x],hd[x]=tot;
    ++tot,to[tot]=x,nt[tot]=hd[y],hd[y]=tot;
}
int fa[N],de[N],sz[N],hs[N],top[N],dfn[N],ti;
void dfs1(int x)
{
    sz[x]=1;
    for(int i=hd[x];i;i=nt[i])
    {
        int y=to[i];
        if(y==fa[x]) continue;
        fa[y]=x,de[y]=de[x]+1,dfs1(y);
        sz[x]+=sz[y],hs[x]=sz[hs[x]]>sz[y]?hs[x]:y;
    }
}
void dfs2(int x,int ntp)
{
    dfn[x]=++ti,top[x]=ntp;
    if(hs[x]) dfs2(hs[x],ntp);
    for(int i=hd[x];i;i=nt[i])
    {
        int y=to[i];
        if(y==fa[x]||y==hs[x]) continue;
        dfs2(y,y);
    }
}
int sb[N*200],ch[N*200][2],tg[N*200],rt[N],tt;
#define mid ((l+r)>>1)
void psup(int o,int len){sb[o]=tg[o]?len:sb[ch[o][0]]+sb[ch[o][1]];}
void modif(int &o,int l,int r,int ll,int rr,int x)
{
    if(!o) o=++tt;
    if(ll<=l&&r<=rr){tg[o]+=x,psup(o,r-l+1);return;}
    if(ll<=mid) modif(ch[o][0],l,mid,ll,rr,x);
    if(rr>mid) modif(ch[o][1],mid+1,r,ll,rr,x);
    psup(o,r-l+1);
}
int merge(int o1,int o2,int l,int r)
{
    if(!o1||!o2) return o1+o2;
    tg[o1]+=tg[o2];
    ch[o1][0]=merge(ch[o1][0],ch[o2][0],l,mid);
    ch[o1][1]=merge(ch[o1][1],ch[o2][1],mid+1,r);
    psup(o1,r-l+1);
    return o1;
}
struct qu
{
    int l,r,x;
};
vector<qu> qq[N];
int glca(int x,int y)
{
    while(top[x]!=top[y])
    {
        if(de[top[x]]<de[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return de[x]<de[y]?x:y;
}
int n,m;
LL ans;
void dfs3(int x)
{
    for(int i=hd[x];i;i=nt[i])
    {
        int y=to[i];
        if(y==fa[x]) continue;
        dfs3(y),rt[x]=merge(rt[x],rt[y],1,n);
    }
    int nn=qq[x].size();
    for(int i=0;i<nn;++i) modif(rt[x],1,n,qq[x][i].l,qq[x][i].r,qq[x][i].x);
    ans+=max(sb[rt[x]]-1,0);
}

int main()
{
    n=rd(),m=rd();
    for(int i=1;i<n;++i) add(rd(),rd());
    de[1]=1,dfs1(1),dfs2(1,1);
    while(m--)
    {
        int x=rd(),y=rd(),lca=glca(x,y),xx=x,yy=y;
        while(top[xx]!=top[yy])
        {
            if(de[top[xx]]<de[top[yy]]) swap(xx,yy);
            qq[x].push_back((qu){dfn[top[xx]],dfn[xx],1});
            qq[y].push_back((qu){dfn[top[xx]],dfn[xx],1});
            qq[lca].push_back((qu){dfn[top[xx]],dfn[xx],-1});
            qq[fa[lca]].push_back((qu){dfn[top[xx]],dfn[xx],-1});
            xx=fa[top[xx]];
        }
        if(de[xx]>de[yy]) swap(xx,yy);
        qq[x].push_back((qu){dfn[xx],dfn[yy],1});
        qq[y].push_back((qu){dfn[xx],dfn[yy],1});
        qq[lca].push_back((qu){dfn[xx],dfn[yy],-1});
        qq[fa[lca]].push_back((qu){dfn[xx],dfn[yy],-1});
    }
    dfs3(1);
    printf("%lld\n",ans>>1);
    return 0;
}
posted @ 2019-05-05 16:07  ✡smy✡  阅读(160)  评论(2编辑  收藏  举报