[JSOI2019]神经网络(树形DP+容斥+生成函数)
首先可以把题目转化一下:把树拆成若干条链,每条链的颜色为其所在的树的颜色,然后排放所有的链成环,求使得相邻位置颜色不同的排列方案数。
然后本题分为两个部分:将一棵树分为1~n条不相交的链的方案数;将这些链安排顺序使得不存在两条相邻的链来自同一棵树。
第一部分显然可以O(n2)树形DP,f[i][j][0/1/2]表示i及其子树j条链,i向儿子连出0/1/2条边的方案数,然后直接背包DP即可。看似O(n3)的树形背包DP其实是O(n2)的。证明复杂度:其实DP时只循环到sz[u]/sz[v]即可,然后可以把每个转移视为儿子v内子树的每个节点和节点u内v外节点组成的点对,于是全部DP完就是枚举了所有的点对,复杂度显然O(n2)。
第二部分,考虑n个点的树划分成i条链的方案是f[i],如果不考虑环只考虑链其对应的指数生成函数为Σf[i]i!(Σ(-1)i-jC(i-1,i-j)xj/j!),其中i∈[1,n],j∈[1,i]。拓展到环上,钦定一棵树作为开头,如果该颜色有i条链,则被算了i次,然后其指数生成函数为:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-1/(j-1)!),其中i∈[1,n],j∈[1,i]。减去首尾同色后,生成函数是这样的:Σf[i](i-1)!(Σ(-1)i-jC(i-1,i-j)xj-2/(j-2)!),其中i∈[2,n],j∈[2,i]。然后暴力卷积即可。
#include<bits/stdc++.h> using namespace std; const int N=5005,mod=998244353; int n,m,sum,ans,fac[N],inv[N],sz[N],f[N][N][3],g[N],tmp[N][3],dp[310][N],b[N]; vector<int>G[N]; int qpow(int a,int b) { int ret=1; while(b) { if(b&1)ret=1ll*ret*a%mod; a=1ll*a*a%mod,b>>=1; } return ret; } void dfs(int u,int fa) { sz[u]=1,f[u][1][0]=1; for(int i=0;i<G[u].size();i++) if(G[u][i]!=fa) { int v=G[u][i]; dfs(v,u); for(int j=0;j<=sz[u]+sz[v];j++)tmp[j][0]=tmp[j][1]=tmp[j][2]=0; for(int j=1;j<=sz[u];j++) for(int k=1;k<=sz[v];k++) { tmp[j+k][0]=(tmp[j+k][0]+1ll*f[u][j][0]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod; tmp[j+k-1][1]=(tmp[j+k-1][1]+1ll*f[u][j][0]*(f[v][k][0]+f[v][k][1]))%mod; tmp[j+k][1]=(tmp[j+k][1]+1ll*f[u][j][1]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod; tmp[j+k-1][2]=(tmp[j+k-1][2]+1ll*f[u][j][1]*(f[v][k][0]+f[v][k][1]))%mod; tmp[j+k][2]=(tmp[j+k][2]+1ll*f[u][j][2]*(f[v][k][0]+2ll*f[v][k][1]+2ll*f[v][k][2]))%mod; } sz[u]+=sz[v]; for(int j=1;j<=sz[u];j++)f[u][j][0]=tmp[j][0],f[u][j][1]=tmp[j][1],f[u][j][2]=tmp[j][2]; } } int C(int a,int b){return a<b?0:1ll*fac[a]*inv[b]%mod*inv[a-b]%mod;} int S(int a,int b){return (!a&&!b)?1:1ll*fac[a]*C(a-1,a-b)%mod;} int main() { fac[0]=1;for(int i=1;i<=5000;i++)fac[i]=1ll*fac[i-1]*i%mod; for(int i=0;i<=5000;i++)inv[i]=qpow(fac[i],mod-2); scanf("%d",&m); dp[0][0]=1; for(int p=1;p<=m;p++) { scanf("%d",&n); for(int i=1;i<=n;i++)G[i].clear(); for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x); for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) f[i][j][0]=f[i][j][1]=f[i][j][2]=0; dfs(1,0); memset(g,0,sizeof g); for(int i=1;i<=n;i++)g[i]=(f[1][i][0]+2ll*f[1][i][1]+2ll*f[1][i][2])%mod; if(p!=m) { memset(b,0,sizeof b); for(int j=1;j<=n;j++) if(g[j])for(int k=0,t=1;k<=j;k++,t=mod-t) b[j-k]=(b[j-k]+1ll*t*S(j,j-k)%mod*g[j])%mod; for(int i=0;i<=sum;i++) if(dp[p-1][i])for(int j=0;j<=n;j++) dp[p][i+j]=(dp[p][i+j]+1ll*C(i+j,j)*b[j]%mod*dp[p-1][i])%mod; } else{ memset(b,0,sizeof b); for(int j=1;j<=n;j++) if(g[j])for(int k=0,t=1;k<j;k++,t=mod-t) b[j-1-k]=(b[j-1-k]+1ll*t*S(j-1,j-k-1)%mod*g[j])%mod; for(int i=0;i<=sum;i++) if(dp[p-1][i])for(int j=0;j<=n;j++) ans=(ans+1ll*C(i-2+j,j)*b[j]%mod*dp[p-1][i])%mod; } sum+=n; } printf("%d",ans); }