Comet OJ - Contest #6, Problem D. 另一道树题 ("Another Tree Problem") — 并查集 + 思维 + 计数 (union-find + key insight + counting)
Code:
// Comet OJ Contest #6 D. Input: n, then the parent f[i] of every vertex
// i = 2..n; the tree is rooted at vertex 1. Output: one number mod 998244353.
//
// What the program computes (derived from the code below):
//   A union-find over the vertices is grown in rounds t = 1..n; vi[t] holds
//   the pairs merged in round t.
//     * (i, f[i]) is merged in round dep[i] - 1, so after round t every
//       vertex of depth <= t + 1 lies in the root's component.
//     * Two vertices u, v that are neighbours in DFS pre-order within the same
//       depth level are merged in round dep[u] - dep[lca(u, v)].
//   mul = prod over components C of (|C| + 1) (mod p): the number of vertex
//   subsets holding at most one vertex per component. mul - (n + 1) removes
//   the empty set and the n singletons. The answer sums that count over the
//   initial state and the state after each of the n rounds.
//   NOTE(review): this matches "each chosen vertex climbs one edge per step and
//   stops at the root; sum, over subsets of size >= 2, of the first step at
//   which two of them coincide" — confirm against the problem statement.
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>

// Local-debugging helper: redirect stdin from "<s>.in".
#define setIO(s) freopen(s".in","r",stdin)

typedef long long ll;

// Capacity of every per-vertex array; the inverse table is indexed up to
// n + 1 and G / vi up to n, so n must stay below maxn - 1.
const int maxn = 200004;
const ll mod = 998244353;

// (a + b) mod p for non-negative a, b.
inline ll add_mod(ll a, ll b) { return (a + b) % mod; }
// (a - b) mod p, always in [0, p) for non-negative a, b.
inline ll sub_mod(ll a, ll b) { return ((a % mod) - (b % mod) + mod) % mod; }

int n, edges;
// Child lists stored as singly linked lists: hd[u] is u's first edge, nex[e]
// the following edge, to[e] the child. Children are visited last-added-first.
int hd[maxn], to[maxn], nex[maxn];
// dep: depth (root = 1); f: parent (f[1] == 0); siz: subtree size;
// son: heavy child (0 for a leaf); top: head of the heavy path.
int dep[maxn], f[maxn], siz[maxn], son[maxn], top[maxn];
// Union-find parent pointer and component size.
int p[maxn], sz[maxn];
// inv[i] = i^{-1} mod p; mul = prod (component size + 1); ans = result.
ll inv[maxn], mul, ans;

// A pair of vertices to be merged.
struct Node {
    int u, v;
    Node(int a = 0, int b = 0) : u(a), v(b) {}
};

std::vector<Node> vi[maxn];  // vi[t]: pairs merged in round t
std::vector<int> G[maxn];    // G[d]: vertices of depth d, in DFS pre-order

// base^k mod p by binary exponentiation.
ll qpow(ll base, ll k) {
    ll re = 1;
    for (; k; k >>= 1) {
        if (k & 1) re = re * base % mod;
        base = base * base % mod;
    }
    return re;
}

// Prepend child v to u's child list.
void add_edge(int u, int v) {
    nex[++edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
}

// Pre-order "entry" of vertex v for build_tree: set depth, reset subtree size,
// record v in its depth level and push it on the explicit stack.
void enter_vertex(int v, std::vector<int>& cur, std::vector<int>& stk) {
    dep[v] = dep[f[v]] + 1;  // dep[0] == 0 and f[1] == 0, so the root gets 1
    siz[v] = 1;
    G[dep[v]].push_back(v);
    cur[v] = hd[v];
    stk.push_back(v);
}

// DFS from the root without recursion (a path of ~2e5 vertices would otherwise
// need ~2e5 nested calls). Fills dep[], siz[], G[] in the same pre-order a
// recursive walk over the child lists would produce, and son[] = the first
// child (in child-list order) of strictly maximal subtree size.
void build_tree() {
    std::vector<int> cur(n + 2, 0);  // next child edge still to visit, per vertex
    std::vector<int> stk;
    enter_vertex(1, cur, stk);
    while (!stk.empty()) {
        const int u = stk.back();
        if (cur[u] != 0) {
            const int v = to[cur[u]];
            cur[u] = nex[cur[u]];
            enter_vertex(v, cur, stk);
        } else {
            // All children of u are finished: fold u into its parent.
            stk.pop_back();
            const int fa = f[u];
            if (fa != 0) {
                siz[fa] += siz[u];
                if (siz[u] > siz[son[fa]]) son[fa] = u;
            }
        }
    }
}

// Heavy-path heads: a heavy child continues its parent's path, every other
// vertex starts a new one. Walking G level by level sees each parent before
// its children, so top[f[u]] is always ready.
void assign_heavy_tops() {
    for (int d = 1; d <= n; ++d) {
        for (std::size_t k = 0; k < G[d].size(); ++k) {
            const int u = G[d][k];
            const int fa = f[u];
            top[u] = (fa != 0 && son[fa] == u) ? top[fa] : u;
        }
    }
}

// Lowest common ancestor via the heavy-path decomposition.
int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]])
            x = f[top[x]];
        else
            y = f[top[y]];
    }
    return dep[x] < dep[y] ? x : y;
}

// Every vertex starts as its own singleton component.
void init() {
    for (int i = 0; i < maxn; ++i) {
        p[i] = i;
        sz[i] = 1;
    }
}

// Component representative, with iterative path compression.
int find_root(int x) {
    int r = x;
    while (p[r] != r) r = p[r];
    while (p[x] != r) {
        const int nx = p[x];
        p[x] = r;
        x = nx;
    }
    return r;
}

// Merge the components of x and y (union by size) and update mul: drop the two
// old factors (size + 1) and multiply in the merged one.
void unite(int x, int y) {
    int fx = find_root(x), fy = find_root(y);
    if (fx == fy) return;
    if (sz[fx] > sz[fy]) std::swap(fx, fy);
    mul = mul * inv[sz[fx] + 1] % mod * inv[sz[fy] + 1] % mod;
    p[fx] = fy;
    sz[fy] += sz[fx];
    mul = mul * (sz[fy] + 1) % mod;
}

// Reset the union-find, build the linear-time inverse table
// inv[i] = -(p / i) * inv[p mod i], and set mul = 2^n (n singletons).
void Initialize() {
    init();
    inv[1] = 1;
    for (int i = 2; i < maxn; ++i) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
    mul = qpow(2, n);
}

int main() {
    // setIO("input");
    if (std::scanf("%d", &n) != 1) return 0;
    for (int i = 2; i <= n; ++i) {
        if (std::scanf("%d", &f[i]) != 1) return 0;
        add_edge(f[i], i);
    }
    Initialize();
    build_tree();
    assign_heavy_tops();

    // Child-parent pairs join in round dep[child] - 1.
    for (int i = 2; i <= n; ++i) vi[dep[i] - 1].push_back(Node(i, f[i]));
    // Same-depth pre-order neighbours join once both have climbed to their LCA;
    // consecutive pairs suffice because the LCA of a pre-order range is the
    // shallowest LCA among its consecutive pairs.
    for (int d = 1; d <= n; ++d) {
        for (std::size_t j = 1; j < G[d].size(); ++j) {
            const int u = G[d][j], v = G[d][j - 1];
            vi[dep[u] - dep[LCA(u, v)]].push_back(Node(u, v));
        }
    }

    ans = sub_mod(mul, n + 1);  // round 0: nothing merged yet
    for (int t = 1; t <= n; ++t) {
        for (std::size_t j = 0; j < vi[t].size(); ++j) unite(vi[t][j].u, vi[t][j].v);
        ans = add_mod(ans, sub_mod(mul, n + 1));
    }
    std::printf("%lld\n", ans);
    return 0;
}