AtCoder Beginner Contest 163 F - path pass i

传送门:https://atcoder.jp/contests/abc163/tasks/abc163_f

题目大意:一棵 n 个节点的树,每个节点有一个颜色(颜色编号为 1~n)。对每一种颜色,求至少经过一个该颜色节点的简单路径数量(路径视为无序点对,起点与终点可以相同,因此路径总数为 n(n+1)/2)。

分析:虽然有 O(n) 的做法,但这里还是先贴一下虚树的做法。思路是补集计数:对每一种颜色建立虚树(该颜色的所有节点为标记节点,1 号点作为虚树的根)。对每一个标记节点,分别统计它每个儿子子树中不含标记节点的连通块大小 k,从路径总数 n(n+1)/2 中减去 k(k+1)/2;若根未被标记,根所在的连通块也要同样减去。连通块大小通过"儿子子树大小减去其中最上层标记节点的子树大小之和"来计算。

#include<bits/stdc++.h>

// Competitive-programming template shorthands.
// NOTE(review): all(x) and SZ(x) do not parenthesize x; pass simple names only.
#define all(x) x.begin(),x.end()
#define fi first
#define sd second
// lson/rson/mid are segment-tree helpers; they are not used in this program.
#define lson (nd<<1)
#define rson (nd+nd+1)
#define PB push_back
#define mid (l+r>>1)
#define MP make_pair
#define SZ(x) (int)x.size()

using namespace std;

typedef long long LL;

typedef vector<int> VI;

typedef pair<int,int> PII;

// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. Returns 0 if end of input is reached
// before any digit. The character is kept in an int so EOF stays
// distinguishable; the previous char-based loop spun forever at EOF.
inline int read(){
    int res=0, f=1;
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        res=res*10+ch-'0';
        ch=getchar();
    }
    return res*f;
}

// Capacity for vertex-indexed arrays (200000 + 5).
const int MAXN = 200'005;

// Modulus used by the modular helpers below.
const int MOD = 1000000007;

// a <- (a + b) mod MOD; expects a and b already reduced into [0, MOD).
void addmod(int& a, int b){
    const int s=a+b;
    a=(s>=MOD)?s-MOD:s;
}

// Returns (a * b) mod MOD, multiplying in 64 bits so the product cannot overflow.
int mulmod(int a, int b){
    const long long prod=static_cast<long long>(a)*b;
    return static_cast<int>(prod%MOD);
}

// Lowers a to b when a > b.
template<typename T>
void chmin(T& a, T b){
    if(!(a>b))return;
    a=b;
}

// Raises a to b when b > a.
template<typename T>
void chmax(T& a, T b){
    if(!(b>a))return;
    a=b;
}

// Forward-star adjacency storage for the input tree. Arc ids start at 1,
// so id 0 terminates a list (head[u]==0 means "no arcs").
#define go(e,u) for(int e=head[u];e;e=Next[e])
int to[MAXN<<1],Next[MAXN<<1],head[MAXN],tol;

// Adds the undirected edge {u, v} as the two arcs u->v and v->u, each
// prepended to its source vertex's list.
void add_edge(int u,int v){
    const int ends[2][2]={{u,v},{v,u}};
    for(const auto& arc:ends){
        const int from=arc[0];
        const int id=++tol;
        to[id]=arc[1];
        Next[id]=head[from];
        head[from]=id;
    }
}

// Forward-star storage for the virtual tree. Only parent -> child arcs are
// stored; headv/tolv are reset per color by the caller.
#define gov(e,u) for(int e=headv[u];e;e=Nextv[e])
int tov[MAXN<<1],Nextv[MAXN<<1],headv[MAXN],tolv;

// Prepends the directed arc u -> v to u's virtual-tree adjacency list.
void add_edgev(int u,int v){
    const int id=++tolv;
    tov[id]=v;
    Nextv[id]=headv[u];
    headv[u]=id;
}

// n: number of vertices; col[i]: color of vertex i.
int n, col[MAXN];

// nodes[c]: the vertices whose color is c (sorted by dfn before use).
vector<int> nodes[MAXN];

// dfn[u]: pre-order index of u (starting at 1); R[u]: largest dfn inside u's
// subtree, so u's subtree is exactly the dfn range [dfn[u], R[u]].
int dfn[MAXN], R[MAXN], dfncnt;
// up[u][k]: 2^k-th ancestor of u (0 above the root); dep[u]: depth (root = 1);
// st/top: stack used while building the virtual tree; sz[u]: subtree size.
int up[MAXN][25], dep[MAXN], st[MAXN], top, sz[MAXN];
// mark[u] == 1 iff u has the color currently being processed.
int mark[MAXN];

// Answer for the color currently being processed.
LL ans;

void dfs(int u, int f){
    sz[u]=1;
    dfn[u]=++dfncnt;

    for(int i=0;up[u][i];++i)up[u][i+1]=up[up[u][i]][i];

    go(e,u){
        int v=to[e];
        if(v==f)continue;
        up[v][0]=u;
        dep[v]=dep[u]+1;
        dfs(v,u);
        sz[u]+=sz[v];
    }
    R[u]=dfncnt;
}

// Lowest common ancestor of u and v via the binary-lifting table up[][]
// (jumps of 2^20 down to 2^0). A jump past the root lands on vertex 0, whose
// depth 0 is below every real depth (the root has depth 1), so such jumps are
// rejected while equalizing depths; in the second phase both sides would land
// on 0 together and are therefore treated as equal.
int getLCA(int u, int v){
    int a=u, b=v;
    if(dep[a]<dep[b])std::swap(a,b);   // a is now the deeper vertex

    // Lift a up to b's depth.
    for(int k=20;k>=0;--k){
        const int anc=up[a][k];
        if(dep[anc]>=dep[b])a=anc;
    }
    if(a==b)return a;

    // Lift both while their ancestors differ; they stop just below the LCA.
    for(int k=20;k>=0;--k){
        if(up[a][k]==up[b][k])continue;
        a=up[a][k];
        b=up[b][k];
    }
    return up[a][0];
}

bool cmp(int x, int y){
    return dfn[x]<dfn[y];
}

// Orders (dfn, value) pairs by their first component only.
// The caller must sort with this before scanning; leaving the vector
// unsorted produced Wrong Answer.
bool cmp2(PII x, PII y){
    return y.fi>x.fi;
}

// Post-order walk over the virtual tree of the current color, rooted at u.
// For a marked vertex u (mark[u]==1, i.e. u has the current color), each
// real-tree child subtree of u, minus the parts cut off below marked virtual
// descendants, is a connected block of k vertices without the current color;
// the k*(k+1)/2 paths inside it are subtracted from ans. If vertex 1 is not
// marked, the block containing it (size sz[1] minus everything cut off) is
// subtracted as well.
// Returns sz[u] if u is marked; otherwise the sum of the values returned by
// u's virtual children (the total size cut off below u).
LL dfs1(int u){
    LL s=0;

    // (dfn of virtual child, size cut off by that child); filled only when u is marked.
    vector<PII> num;
    gov(e,u){
        int v=tov[e];
        LL t=dfs1(v);
        s+=t;
        if(mark[u]) num.PB(MP(dfn[v],t));
    }

    // Sort by dfn so the single forward scan below can match entries to children.
    sort(all(num),cmp2);

    if(mark[u]){
        int idx=0;
        // Real-tree children are iterated in the same adjacency order dfs() used,
        // so their dfn ranges [dfn[v], R[v]] appear in increasing order.
        go(e,u){
            int v=to[e];
            if(v==up[u][0])continue;   // skip the parent
            // cc: total size cut off inside v's subtree by virtual descendants.
            int cc=0;
            while(idx<SZ(num)&&num[idx].fi>=dfn[v]&&num[idx].fi<=R[v]){
                cc+=num[idx].sd;
                ++idx;
            }

            ans-=1ll*(sz[v]-cc)*(sz[v]-cc+1)/2;
        }
    }

    if(u==1&&!mark[u]){
        // Block containing the unmarked root.
        // NOTE(review): this LL `num` shadows the vector `num` declared above.
        LL num=sz[1]-s;
        ans-=1ll*num*(num+1)/2;
    }

    if(mark[u])return sz[u];
    else return s;
}

// Reads the tree and precomputes DFS data. Then, for every color, builds the
// virtual tree of that color's vertices (rooted at vertex 1) and prints the
// number of paths that touch at least one vertex of that color.
// Output uses '\n' instead of endl so stdout is not flushed on each of up to
// n lines; cout is flushed once at program exit.
int main(){
    n=read();
    for(int i=1;i<=n;++i){
        col[i]=read();
        nodes[col[i]].PB(i);
    }

    for(int i=1;i<n;++i){
        int u=read(),v=read();
        add_edge(u,v);
    }

    // Root at vertex 1 with depth 1, so vertex 0 ("no ancestor") keeps depth 0.
    dep[1]=1;
    dfs(1,0);
    for(int color=1;color<=n;++color){
        if(!SZ(nodes[color])){
            // No vertex has this color, so no path touches it.
            cout<<0<<'\n';
            continue;
        }

        // Complement counting: start from all n*(n+1)/2 paths (single-vertex
        // paths included); dfs1 subtracts the paths avoiding this color.
        ans=1ll*n*(n+1)/2;
        // The stack construction below needs the vertices in dfn order.
        sort(all(nodes[color]),cmp);

        // Build the virtual tree. st[1..top] holds a root-to-vertex chain, with
        // vertex 1 always at the bottom, making it the virtual root. Adjacency
        // lists (headv) are reset only for vertices pushed onto the stack.
        // st[0] is never written, so dfn[st[0]]==0 stops the pop loop at the bottom.
        st[top=1]=1;headv[1]=0;tolv=0;
        for(int i=0;i<SZ(nodes[color]);++i){
            int nn=nodes[color][i];
            mark[nn]=1;
            if(nn==1)continue;   // vertex 1 is already on the stack

            int l=getLCA(st[top],nn);

            if(l!=st[top]){
                // nn leaves the current chain: pop and link vertices deeper than l.
                while(dfn[l]<dfn[st[top-1]]){
                    add_edgev(st[top-1],st[top]);
                    --top;
                }
                if(dfn[l]>dfn[st[top-1]]){
                    // l is new: it replaces the top entry and adopts it as a child.
                    headv[l]=0;add_edgev(l,st[top]);st[top]=l;
                }else{
                    // l is already the next entry down: link the top to it and pop.
                    add_edgev(l,st[top--]);
                }
            }
            headv[nn]=0;st[++top]=nn;
        }

        // Link the chain that is still on the stack.
        for(int i=1;i<top;++i){
            add_edgev(st[i],st[i+1]);
        }

        dfs1(st[1]);
        cout<<ans<<'\n';
        // Clear the marks before processing the next color.
        for(int i=0;i<SZ(nodes[color]);++i)mark[nodes[color][i]]=0;
    }

    return 0;
}
View Code

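顺便附上 O(n) 的做法:只做一次 DFS,对每种颜色 c 维护 sum[c],表示已经回溯完毕的节点中"最上层"的 c 色节点的子树大小之和。处理节点 u(颜色为 c)的儿子 v 时,进入 v 前后 sum[c] 的增量就是 v 子树中被 c 色节点覆盖的大小,于是 sz[v] 减去它就是紧挨在 u 下方、不含 c 色节点的连通块大小;u 回溯时令 sum[c] = 进入 u 时的值 + sz[u]。最后对每种颜色 c,根所在连通块的大小为 n − sum[c],同样从总数中减去。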

#include<bits/stdc++.h>

// Competitive-programming template shorthands.
// NOTE(review): all(x) and SZ(x) do not parenthesize x; pass simple names only.
#define all(x) x.begin(),x.end()
#define fi first
#define sd second
// lson/rson/mid are segment-tree helpers; they are not used in this program.
#define lson (nd<<1)
#define rson (nd+nd+1)
#define PB push_back
#define mid (l+r>>1)
#define MP make_pair
#define SZ(x) (int)x.size()

using namespace std;

typedef long long LL;

typedef vector<int> VI;

typedef pair<int,int> PII;

// Reads the next decimal integer from stdin as a 64-bit value (the same type
// as the file's LL). Non-digit characters before the number are skipped; a
// '-' seen while skipping makes the result negative. Returns 0 if end of
// input is reached before any digit. The character is kept in an int so EOF
// stays distinguishable; the previous char-based loop spun forever at EOF.
inline long long read(){
    long long res=0, f=1;
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        res=res*10+ch-'0';
        ch=getchar();
    }
    return res*f;
}

// Capacity for vertex-indexed arrays (200000 + 5).
const int MAXN = 200'005;

// Modulus used by the modular helpers below.
const int MOD = 1000000007;

// a <- (a + b) mod MOD; expects a and b already reduced into [0, MOD).
void addmod(int& a, int b){
    const int s=a+b;
    a=(s>=MOD)?s-MOD:s;
}

// Returns (a * b) mod MOD, multiplying in 64 bits so the product cannot overflow.
int mulmod(int a, int b){
    const long long prod=static_cast<long long>(a)*b;
    return static_cast<int>(prod%MOD);
}

// Lowers a to b when a > b.
template<typename T>
void chmin(T& a, T b){
    if(!(a>b))return;
    a=b;
}

// Raises a to b when b > a.
template<typename T>
void chmax(T& a, T b){
    if(!(b>a))return;
    a=b;
}

// Number of vertices (declared LL, so n*(n+1)/2 below is evaluated in 64 bits).
LL n;

// sz[u]: subtree size; sum[c]: total size of the subtrees rooted at the highest
// c-colored vertices whose DFS has already returned; ans[c]: answer for color c.
LL sz[MAXN], sum[MAXN], ans[MAXN];
// col[i]: color of vertex i.
LL col[MAXN];

// Forward-star adjacency storage for the input tree. Arc ids start at 1,
// so id 0 terminates a list (head[u]==0 means "no arcs").
#define go(e,u) for(int e=head[u];e;e=Next[e])
int to[MAXN<<1],Next[MAXN<<1],head[MAXN],tol;

// Adds the undirected edge {u, v} as the two arcs u->v and v->u, each
// prepended to its source vertex's list.
void add_edge(int u,int v){
    const int ends[2][2]={{u,v},{v,u}};
    for(const auto& arc:ends){
        const int from=arc[0];
        const int id=++tol;
        to[id]=arc[1];
        Next[id]=head[from];
        head[from]=id;
    }
}

// Number of paths (unordered vertex pairs, endpoints allowed to coincide)
// inside a connected block of x vertices: x*(x+1)/2.
LL calc(LL x){
    const LL twice=x*(x+1);
    return twice/2;
}

void dfs(int u,int f){
    int c=col[u];
    sz[u]=1;LL o=sum[c];
    go(e,u){
        int v=to[e];
        if(v==f)continue;
        LL t=sum[c];
        dfs(v,u);
        ans[c]-=calc(sz[v]-(sum[c]-t));
        sz[u]+=sz[v];
    }
    sum[col[u]]=o+sz[u];
}

// Reads the input and runs the single DFS. Then, for every color i, subtracts
// the paths inside the block of vertices not covered by any i-colored subtree
// (size n - sum[i], possibly 0) and prints the answer.
// Output uses '\n' instead of endl so stdout is not flushed on each of the n
// lines; cout is flushed once at program exit.
int main(){
    n=read();
    // Every answer starts from all n*(n+1)/2 paths, single-vertex paths included.
    for(int i=1;i<=n;++i)col[i]=read(),ans[i]=n*(n+1)/2;

    for(int i=1,u,v;i<n;++i){
        u=read();
        v=read();
        add_edge(u,v);
    }

    dfs(1,0);

    for(int i=1;i<=n;++i){
        // Vertices outside every i-colored subtree form the block around the root.
        LL t=n-sum[i];
        ans[i]-=calc(t);
        cout<<ans[i]<<'\n';
    }

    return 0;
}
View Code

 

posted @ 2020-04-25 16:52  John_Ran  阅读(325)  评论(0编辑  收藏  举报