Comet OJ - Contest #6 D. 另一道树题 ("Another Tree Problem") — 并查集 (union-find) + 思维 (insight) + 计数 (counting)

Code: 

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <utility>
#include <vector>
// Redirects stdin to "<s>.in" for local testing (needs macro string pasting).
#define setIO(s) freopen(s".in","r",stdin)
// Typed constants instead of object-like macros: identical values and types
// (plain int, like the former integer literals), but scoped and debuggable.
const int maxn = 200004;    // capacity of every per-node / per-edge array
const int mod = 998244353;  // prime modulus used by all modular arithmetic
typedef long long ll;
using namespace std;
// Modular sum: (a + b) reduced modulo `mod`.
// For non-negative a and b the result lies in [0, mod); a + b must not overflow.
inline ll add(ll a,ll b) {
    const ll sum = a + b;
    return sum % mod;
}
// Modular difference: (a - b) modulo `mod`.
// Both operands are reduced first and `mod` is added before the final
// reduction, so for non-negative a and b the result lands in [0, mod).
inline ll de(ll a,ll b) {
    const ll ra = a % mod;
    const ll rb = b % mod;
    return (ra - rb + mod) % mod;
}
int n;      // number of tree nodes (read in main)
int edges;  // number of edges stored so far in the adjacency lists
// Adjacency lists as singly linked edge lists: hd[u] is the newest edge out
// of u, nex[e] the next edge in the same list, to[e] the edge's head node.
int hd[maxn], to[maxn], nex[maxn];
// Tree data filled by dfs1/dfs2: depth (root gets 1), parent, subtree size,
// heavy child, and top node of the heavy chain.
int dep[maxn], f[maxn], siz[maxn], son[maxn], top[maxn];
// Disjoint-set forest: parent pointer and set size (kept up to date at roots).
int p[maxn], sz[maxn];
// inv[i]: modular inverse of i; mul: running product of (set size + 1);
// g: not referenced in the visible code; ans: accumulated answer.
ll inv[maxn];
ll mul;
ll g[maxn];
ll ans;
// A pair of node indices (u, v); main stores pairs to be unioned in the DSU.
struct Node {
    int u;
    int v;
    Node(int a = 0, int b = 0) : u(a), v(b) {}
};
vector<Node>vi[maxn];   
vector<int>G[maxn];          
// Computes base^k modulo `mod` by binary (square-and-multiply) exponentiation.
// base may be any value: it is reduced into [0, mod) first, so base * base
// never exceeds mod^2 and cannot overflow a 64-bit long long.
// k is expected to be non-negative; the loop stops once k is no longer
// positive, so a negative k terminates (returning 1) instead of spinning forever.
ll qpow(ll base,ll k) {
    base %= mod;
    if (base < 0) base += mod;  // % keeps the dividend's sign in C++
    ll re = 1;
    while (k > 0) {
        if (k & 1) re = re * base % mod;
        base = base * base % mod;
        k >>= 1;
    }
    return re;
}
void add(int u,int v) {
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; 
}
// First heavy-light pass over the subtree of u.
// On entering a node x: dep[x] = dep[f[x]] + 1, siz[x] = 1, and x is appended
// to G[dep[x]] (so each G[d] lists depth-d nodes in visit order).
// After a child v finishes: siz[x] += siz[v], and son[x] becomes v when v's
// subtree is strictly larger than the current heavy child's.
// Implemented with an explicit stack instead of recursion, so a path-shaped
// tree of ~2e5 nodes cannot overflow the call stack. Children are examined in
// exactly the original edge-list order, so G[] order and son[] tie-breaking
// match the recursive version.
void dfs1(int u) {
    std::vector<int> nodes;   // current root-to-node path being explored
    std::vector<int> cursor;  // per stacked node: next edge id, or -1 = not yet entered
    nodes.push_back(u);
    cursor.push_back(-1);
    while (!nodes.empty()) {
        const int x = nodes.back();
        if (cursor.back() == -1) {
            // Pre-order work (edge ids start at 1, so -1 is never a real edge).
            dep[x] = dep[f[x]] + 1, siz[x] = 1, G[dep[x]].push_back(x);
            cursor.back() = hd[x];
            continue;
        }
        const int e = cursor.back();
        if (e == 0) {
            // All edges of x examined: fold x's subtree into its stack parent.
            nodes.pop_back();
            cursor.pop_back();
            if (!nodes.empty()) {
                const int par = nodes.back();
                siz[par] += siz[x];
                if (siz[x] > siz[son[par]]) son[par] = x;
            }
            continue;
        }
        cursor.back() = nex[e];  // advance before pushing (push may reallocate)
        const int v = to[e];
        if (v == f[x]) continue;
        nodes.push_back(v);
        cursor.push_back(-1);
    }
}
void dfs2(int u,int tp) {
    top[u]=tp; 
    if(son[u]) dfs2(son[u],tp); 
    for(int i=hd[u];i;i=nex[i]) {
        int v=to[i]; 
        if(v==f[u]||v==son[u]) continue;   
        dfs2(v,v); 
    }
} 
// Lowest common ancestor via the heavy-light chains built by dfs1/dfs2.
// While x and y are on different chains, the one whose chain top is deeper
// (y on ties) jumps to the parent of its chain top. Once both share a chain,
// the shallower of the two is the LCA.
int LCA(int x,int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]])
            x = f[top[x]];
        else
            y = f[top[y]];
    }
    if (dep[x] < dep[y]) return x;
    return y;
}
// Resets the disjoint-set forest: every index in [0, maxn) becomes the root
// of its own singleton set of size 1.
void init() {
    for (int idx = 0; idx < maxn; ++idx) {
        p[idx] = idx;
        sz[idx] = 1;
    }
}
// Returns the root (representative) of x's set in the disjoint-set forest.
// Full path compression: every node on the x-to-root path is re-pointed
// directly at the root, leaving p[] exactly as the recursive version would.
// Done iteratively in two passes, because merge() can build parent chains as
// long as n and a recursive find would then recurse that deep.
int find(int x) {
    int root = x;
    while (p[root] != root) root = p[root];
    while (p[x] != root) {
        const int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}
// Unions the sets containing x and y; does nothing if they already share a root.
// Keeps mul as a running product of (set size + 1) factors: the factors of the
// two old sets are divided out via inv[] (expected to hold modular inverses,
// filled in Initialize), and the merged set's factor is multiplied in.
// Union by size: the smaller root goes under the larger one, which keeps every
// tree O(log n) deep. The mul update is symmetric in the two sizes, so this
// choice does not change mul.
void merge(int x,int y) {
    int fx = find(x), fy = find(y);
    if (fx == fy) return;
    if (sz[fx] > sz[fy]) std::swap(fx, fy);
    // Each step multiplies two values below mod (< 2^30), so it fits in long long.
    mul = mul * inv[sz[fx] + 1] % mod * inv[sz[fy] + 1] % mod;
    p[fx] = fy;
    sz[fy] += sz[fx];
    mul = mul * (sz[fy] + 1) % mod;
}
void Initialize() {
    init(),inv[1]=1,mul=qpow(2,n);       
    for(int i=2;i<maxn;++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;       
} 
int main() {
    // setIO("input");    
    scanf("%d",&n); 
    for(int i=2;i<=n;++i) scanf("%d",&f[i]), add(f[i],i);        
    Initialize(),dfs1(1),dfs2(1,1);          
    for(int i=2;i<=n;++i) vi[dep[i]-1].push_back(Node(i,f[i]));    
    for(int i=1;i<=n;++i) {
        for(int j=1;j<G[i].size();++j) { 
            vi[dep[G[i][j]] - dep[LCA(G[i][j], G[i][j-1])]].push_back(Node(G[i][j], G[i][j-1]));    
        }
    }  
    ans=de(mul,n+1);                        
    for(int i=1;i<=n;++i) {
        for(int j=0;j<vi[i].size();++j) 
            merge(vi[i][j].u, vi[i][j].v); 
        ans=add(ans,de(mul,n+1));                  
    }
    printf("%lld\n",ans);      
    return 0; 
}

  

posted @ 2019-07-25 08:53  EM-LGH  阅读(188)  评论(0)  编辑  收藏  举报  [views: 188 · comments: 0 · edit · favorite · report]