[JSOI2019]神经网络 题解
Solution
我们可以发现,这其实相当于把每棵树拆成若干条链拼起来,开头必须是第一棵树 \(1\) 所在的链,要求相邻两个不是同一棵树,并且最后一个不是第 \(1\) 棵树的链,对于长度 \(> 1\) 的链我们会有 \(2\) 的贡献,求贡献和。注意,我们并不要求第 \(1\) 条链必须以 \(1\) 开头,所以方案还是正常考虑。
可以发现如果没有相邻两个不是同一棵树内部的限制,那么我们可以直接 EGF 暴力乘起来。那么我们可以考虑容斥。对于一棵树(不是第 \(1\) 棵),设 \(f_i\) 表示拆成 \(i\) 条链的贡献,那么其 EGF 为:
\[\sum_{i=1}^{n} f_ii!\sum_{j=1}^{i} \binom{i-1}{j-1}(-1)^{i-j}\frac{x^j}{j!}
\]
可以发现对于第 \(1\) 棵,我们先不考虑最后一条链的限制,因为节点 \(1\) 所在的链要在第一个,可以得到其 EGF 为:
\[\sum_{i=1}^{n} f_i(i-1)!\sum_{j=1}^{i} \binom{i-1}{j-1}(-1)^{i-j}\frac{x^{j-1}}{(j-1)!}
\]
再考虑最后一条链的限制,发现直接减去最后一条链来自于第一棵树的 EGF 即可,其为:
\[\sum_{i=1}^{n} f_i(i-1)!\sum_{j=1}^{i} \binom{i-1}{j-1}(-1)^{i-j}\frac{x^{j-2}}{(j-2)!}
\]
然后直接暴力多项式即可。复杂度 \(\Theta(n^2)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define mod 998244353
#define MAXN 5005
// char buf[1<<21],*p1=buf,*p2=buf;
// #define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
template <typename T> inline void chkmax (T &a,T b){a = max (a,b);}
template <typename T> inline void chkmin (T &a,T b){a = min (a,b);}
vector <int> g[MAXN];
int n,M,siz[MAXN],f[MAXN][MAXN][3],h[MAXN][3],tmp[MAXN][3];
int mul (int a,int b){return 1ll * a * b % mod;}
int dec (int a,int b){return a >= b ? a - b : a + mod - b;}
int add (int a,int b){return a + b >= mod ? a + b - mod : a + b;}
int qkpow (int a,int b){
int res = 1;for (;b;b >>= 1,a = mul (a,a)) if (b & 1) res = mul (res,a);
return res;
}
void Add (int &a,int b){a = add (a,b);}
void Sub (int &a,int b){a = dec (a,b);}
void dfs (int u,int fa){
siz[u] = 1;
for (Int v : g[u]) if (v ^ fa) dfs (v,u);
memset (h,0,sizeof (h)),h[1][0] = 1;
for (Int v : g[u]) if (v ^ fa){
memset (tmp,0,sizeof (tmp));
for (Int s1 = 1;s1 <= siz[u];++ s1)
for (Int t1 = 0;t1 < 3;++ t1)
for (Int s2 = 1;s2 <= siz[v];++ s2)
for (Int t2 = 0;t2 < 3;++ t2){
if (t1 < 2 && t2 < 2) Add (tmp[s1 + s2 - 1][t1 + 1],mul (h[s1][t1],f[v][s2][t2]));
Add (tmp[s1 + s2][t1],mul (h[s1][t1],(t2 ? 2 : 1) * f[v][s2][t2]));
}
memcpy (h,tmp,sizeof (h)),siz[u] += siz[v];
}
for (Int i = 1;i <= siz[u];++ i) for (Int t = 0;t < 3;++ t) f[u][i][t] = h[i][t];
}
#define poly vector<int>
#define SZ(A) ((A).size())
poly operator * (poly A,poly B){
poly New;New.resize (SZ(A) + SZ(B) - 1);
for (Int i = 0;i < SZ(A);++ i) for (Int j = 0;j < SZ(B);++ j) Add (New[i + j],mul (A[i],B[j]));
return New;
}
int fac[MAXN],ifac[MAXN];
int binom (int a,int b){return a >= b ? mul (fac[a],mul (ifac[b],ifac[a - b])) : 0;}
signed main(){
read (M);int up = 5e3;
fac[0] = 1;for (Int i = 1;i <= up;++ i) fac[i] = mul (fac[i - 1],i);
ifac[up] = qkpow (fac[up],mod - 2);for (Int i = up;i;-- i) ifac[i - 1] = mul (ifac[i],i);
poly S(1,1);
for (Int T = 1;T <= M;++ T){
read (n);
for (Int x = 1;x <= n;++ x) g[x].clear ();
for (Int i = 2,u,v;i <= n;++ i) read (u,v),g[u].push_back (v),g[v].push_back (u);
dfs (1,0);poly St;St.resize (n + 1);
if (T == 1){
for (Int i = 1;i <= n;++ i){
int fi = add (f[1][i][0],mul (2,add (f[1][i][1],f[1][i][2])));
for (Int j = 1;j <= i;++ j){
int tmp = mul (mul (fi,fac[i - 1]),mul (binom (i - 1,j - 1),mul (i - j & 1 ? mod - 1 : 1,ifac[j - 1])));
Add (St[j - 1],tmp);
if (j > 1){
tmp = mul (mul (fi,fac[i - 1]),mul (binom (i - 1,j - 1),mul (i - j & 1 ? mod - 1 : 1,ifac[j - 2])));
Sub (St[j - 2],tmp);
}
}
}
S = S * St;
}
else{
for (Int i = 1;i <= n;++ i){
int fi = add (f[1][i][0],mul (2,add (f[1][i][1],f[1][i][2])));
for (Int j = 1;j <= i;++ j){
int tmp = mul (mul (fi,fac[i]),mul (binom (i - 1,j - 1),mul (i - j & 1 ? mod - 1 : 1,ifac[j])));
Add (St[j],tmp);
}
}
S = S * St;
}
}
int ans = 0;
for (Int i = 0;i < SZ(S);++ i) Add (ans,mul (fac[i],S[i]));
write (ans),putchar ('\n');
return 0;
}