LG P6478 [NOI Online #2 提高组] 游戏
Description
小 A 和小 B 正在玩一个游戏:有一棵包含 $n=2m$ 个点的有根树(点从 $1\sim n$ 编号),它的根是 $1$ 号点,初始时两人各拥有 $m$ 个点。游戏的每个回合两人都需要选出一个自己拥有且之前未被选过的点,若对手的点在自己的点的子树内,则该回合自己获胜;若自己的点在对方的点的子树内,该回合自己失败;其他情况视为平局。游戏共进行 $m$ 回合。
作为旁观者的你只想知道,在他们随机选点的情况下,第一次非平局回合出现时的回合数的期望值。
为了计算这个期望,你决定对于 $k=0,1,2,\cdots,m$,计算出非平局回合数为 $k$ 的情况数。两种情况不同当且仅当存在一个小 A 拥有的点 $x$,小 B 在 $x$ 被小 A 选择的那个回合所选择的点不同。
由于情况总数可能很大,你只需要输出答案对 $998244353$ 取模后的结果。
Solution
设$f_i$表示钦定有$i$对父子关系,其余的任意选择的方案数,$g_i$为刚好有$i$对父子关系的方案数
那么有$$f_i =\sum_{j=1}^m g_j \binom{j}{i}$$
由二项式反演得
$$g_i = \sum_{j=i}^m (-1)^{j-i} \binom{j}{i} f_j$$
题中所求为$g$,所以求得$f$即可
$f$可以用树上背包求得,背包得到的结果需要再乘上$(m-i)!$
#include<iostream> #include<cstring> #include<cstdio> #include<cmath> using namespace std; int n,head[5005],tot,siz[5005][2],s[5005],m; long long dp[5005][5005],temp[5005],fac[5005]={1},inv[5005]; char str[5005]; const long long mod=998244353; struct Edge { int to,nxt; }edge[10005]; inline int read() { int w=0,f=1; char ch=0; while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();} while(ch>='0'&&ch<='9')w=(w<<1)+(w<<3)+ch-'0',ch=getchar(); return w*f; } long long ksm(long long a,long long p) { long long ret=1; while(p) { if(p&1) (ret*=a)%=mod; (a*=a)%=mod,p>>=1; } return ret; } long long C(int x,int y) { if(y>x||x<0||y<0) return 0; return fac[x]*inv[y]%mod*inv[x-y]%mod; } void dfs(int k,int f) { if(str[k]=='1') siz[k][1]=1; else siz[k][0]=1; s[k]=1; dp[k][0]=1; for(int i=head[k];i;i=edge[i].nxt) { int v=edge[i].to; if(v!=f) { dfs(v,k),memset(temp,0,sizeof(temp)); for(int x=0;x<=min(siz[k][0],siz[k][1]);x++) for(int y=0;y<=min(siz[v][0],siz[v][1]);y++) (temp[x+y]+=dp[k][x]*dp[v][y]%mod)%=mod; siz[k][0]+=siz[v][0],siz[k][1]+=siz[v][1],s[k]+=s[v]; for(int x=0;x<=min(siz[k][0],siz[k][1]);x++) dp[k][x]=temp[x]; } } if(str[k]=='1') for(int i=siz[k][0]-1;~i;i--) (dp[k][i+1]+=dp[k][i]*1ll*(siz[k][0]-i)%mod)%=mod; else for(int i=siz[k][1]-1;~i;i--) (dp[k][i+1]+=dp[k][i]*1ll*(siz[k][1]-i)%mod)%=mod; } int main() { n=read(),m=n>>1,scanf("%s",str+1); for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%mod; inv[n]=ksm(fac[n],mod-2); for(int i=n-1;~i;i--) inv[i]=inv[i+1]*(i+1)%mod; for(int i=1;i<n;i++) { int u=read(),v=read(); edge[++tot]=(Edge){v,head[u]},head[u]=tot,edge[++tot]=(Edge){u,head[v]},head[v]=tot; } dfs(1,0); for(int i=0;i<=m;i++) (dp[1][i]*=fac[m-i])%=mod; for(int i=0;i<=m;i++) { long long g=0; for(int j=i;j<=m;j++) if((j-i)&1) ((g-=C(j,i)*dp[1][j]%mod)+=mod)%=mod; else (g+=C(j,i)*dp[1][j]%mod)%=mod; printf("%lld\n",g); } return 0; }