[JSOI2019]神经网络
ddy讲的牛逼题。
由于树和树之间是完全图,所以我们要做的就是把树拆成一堆路径,之后把这些路径合并起来,就能得到哈密顿回路了;
所以首先对每棵树求一个链划分,设\(dp_{i,j,0/1/2}\)表示在子树\(i\)中划分出了\(j\)条链,\(0\)表示点\(i\)已经划分好了,\(1\)表示点\(i\)自己在一条链中,\(2\)表示点\(i\)在一条还能继续加点的长度大于\(1\)的链中,注意到长度大于\(1\)的链有两个方向,计算贡献的时候需要乘\(2\);大力树上背包即可,复杂度\(O(n^2)\)。
再来考虑合并的问题,现在我们把问题转化成了一共有\(n\)种颜色,每种颜色的有\(a_i\)个可区分的点,求使得不存在两个相邻同色点的环排列的个数;
首先考虑不是环的情况,考虑容斥,我们枚举一下第\(i\)中颜色至多分成\(b_i\)段,大概是
\[\sum \frac{(\sum b_i)!}{\prod b_i!}\prod_{i=1}^n (-1)^{a_i-b_i}\binom{a_i-1}{b_i-1}a_i!
\]
\((-1)^{a_i-b_i}\)是容斥系数,分成\(b_i\)段就用组合数插板一下;由于乘上\(a_i!\),所以所有颜色排列在一起的时候需要保证相对顺序,就是一个有重复元素的排列问题。
不难发现\(\frac{1}{b_i!}\)可以拆到里面来,于是不难想到搞一个\(\rm EGF\);
于是某一棵树的\(\rm EGF\)就是
\[\sum_{i=1}^nf_ii!\sum_{j=1}^i (-1)^{i-j}\binom{i-1}{j-1}\frac{x^j}{j!}
\]
\(f_i\)是把这棵树拆成\(i\)条链的贡献。
再来考虑环的情况,不妨钦定第一棵树的\(1\)号节点作为开头,于是第一棵树的生成函数就是
\[\sum_{i=1}^nf_i(i-1)!\sum_{j=1}^i (-1)^{i-j}\binom{i-1}{j-1}\frac{x^{j-1}}{(j-1)!}
\]
由于首位不能是都是第一棵树,于是我们直接减掉这种情况,直接钦定开头是第一棵树的\(1\)号节点,结尾是第一棵树的某个节点,就有
\[\sum_{i=1}^nf_i(i-1)!\sum_{j=1}^i (-1)^{i-j}\binom{i-1}{j-1}\frac{x^{j-2}}{(j-2)!}
\]
最后把所有\(\rm EGF\)大力卷起来就好了;
代码
#include<bits/stdc++.h>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=5e3+5;const int mod=998244353;
inline int dqm(int x) {return x<0?x+mod:x;}
inline int qm(int x) {return x>=mod?x-mod:x;}
inline int ksm(int a,int b) {
int S=1;for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)S=1ll*S*a%mod;return S;
}
struct E{int v,nxt;}e[maxn<<1];
int fac[maxn],ifac[maxn],lim=5000;
int m,n,sum,num,head[maxn],ans[maxn],f[maxn],g[maxn];
int dp[maxn][maxn][3],tmp[maxn][3],sz[maxn];
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void dfs(int x,int fa) {
sz[x]=1;dp[x][0][1]=1;
for(re int i=head[x];i;i=e[i].nxt) {
if(e[i].v==fa) continue;
dfs(e[i].v,x);int v=e[i].v;
for(re int j=0;j<=sz[v];j++)
for(re int k=0;k<=sz[x];++k) {
if(dp[x][k][0]) {
tmp[k+j][0]=qm(tmp[j+k][0]+1ll*dp[v][j][0]*dp[x][k][0]%mod);
tmp[k+j+1][0]=qm(tmp[j+k+1][0]+1ll*dp[v][j][1]*dp[x][k][0]%mod);
tmp[k+j+1][0]=qm(tmp[j+k+1][0]+2ll*dp[v][j][2]*dp[x][k][0]%mod);
}
if(dp[x][k][1]) {
tmp[k+j][1]=qm(tmp[j+k][1]+1ll*dp[v][j][0]*dp[x][k][1]%mod);
tmp[k+j+1][1]=qm(tmp[j+k+1][1]+1ll*dp[v][j][1]*dp[x][k][1]%mod);
tmp[k+j][2]=qm(tmp[k+j][2]+1ll*dp[v][j][1]*dp[x][k][1]%mod);
tmp[k+j+1][1]=qm(tmp[k+j+1][1]+2ll*dp[v][j][2]*dp[x][k][1]%mod);
tmp[k+j][2]=qm(tmp[k+j][2]+1ll*dp[v][j][2]*dp[x][k][1]%mod);
}
if(dp[x][k][2]) {
tmp[k+j][2]=qm(tmp[j+k][2]+1ll*dp[v][j][0]*dp[x][k][2]%mod);
tmp[k+j+1][2]=qm(tmp[j+k+1][2]+1ll*dp[v][j][1]*dp[x][k][2]%mod);
tmp[k+j+1][0]=qm(tmp[k+j+1][0]+2ll*dp[v][j][1]*dp[x][k][2]%mod);
tmp[k+j+1][2]=qm(tmp[k+j+1][2]+2ll*dp[v][j][2]*dp[x][k][2]%mod);
tmp[k+j+1][0]=qm(tmp[k+j+1][0]+2ll*dp[v][j][2]*dp[x][k][2]%mod);
}
}
sz[x]+=sz[e[i].v];
for(re int j=0;j<=sz[x];++j)
for(re int k=0;k<3;k++) dp[x][j][k]=tmp[j][k],tmp[j][k]=0;
}
}
inline int C(int n,int m) {return 1ll*fac[n]*ifac[n-m]%mod*ifac[m]%mod;}
inline void solve(int id) {
for(re int i=1;i<=n;i++)
for(re int j=0;j<=sz[i];++j)
dp[i][j][0]=dp[i][j][1]=dp[i][j][2]=0;
for(re int i=1;i<=n;i++)head[i]=0;
n=read();sum+=n;num=0;
for(re int x,y,i=1;i<n;i++)
x=read(),y=read(),add(x,y),add(y,x);
dfs(1,0);g[0]=0;
for(re int i=1;i<=n;i++) {
f[i]=dp[1][i][0];g[i]=0;
f[i]=qm(f[i]+dp[1][i-1][1]);
f[i]=qm(f[i]+2ll*dp[1][i-1][2]%mod);
}
if(!id) {
for(re int i=1;i<=n;i++) {
if(!f[i]) continue;int v=1ll*f[i]*fac[i]%mod;
for(re int j=1;j<=i;++j) {
if((i-j)&1) g[j]=dqm(g[j]-1ll*v*C(i-1,j-1)%mod*ifac[j]%mod);
else g[j]=qm(g[j]+1ll*v*C(i-1,j-1)%mod*ifac[j]%mod);
}
}
}
else {
for(re int i=1;i<=n;i++) {
if(!f[i]) continue;int v=1ll*f[i]*fac[i-1]%mod;
for(re int j=1;j<=i;j++) {
if((i-j)&1) g[j-1]=dqm(g[j-1]-1ll*v*C(i-1,j-1)%mod*ifac[j-1]%mod);
else g[j-1]=qm(g[j-1]+1ll*v*C(i-1,j-1)%mod*ifac[j-1]%mod);
}
for(re int j=2;j<=i;j++) {
if((i-j)&1) g[j-2]=qm(g[j-2]+1ll*v*C(i-1,j-1)%mod*ifac[j-2]%mod);
else g[j-2]=dqm(g[j-2]-1ll*v*C(i-1,j-1)%mod*ifac[j-2]%mod);
}
}
}
for(re int nw=0,i=sum;i>=0;ans[i]=nw,nw=0,i--)
for(re int j=0;j<=i&&j<=n;j++) nw=qm(nw+1ll*ans[i-j]*g[j]%mod);
}
int main() {
m=read();fac[0]=ifac[0]=1;ans[0]=1;
for(re int i=1;i<=lim;i++)fac[i]=1ll*fac[i-1]*i%mod;
ifac[lim]=ksm(fac[lim],mod-2);
for(re int i=lim-1;i;i--)ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
for(re int i=1;i<=m;++i)solve(i==1);int cnt=0;
for(re int i=0;i<=sum;i++)cnt=qm(cnt+1ll*ans[i]*fac[i]%mod);
printf("%d\n",cnt);
return 0;
}