Luogu P5333 [JSOI2019]神经网络

Link
一条Hamilton回路可以被拆分成由若干条树上路径组成的环,其中相邻两条树上路径不能同属于一棵树。
假如我们求出了将一棵树分为若干链的方案数,那么剩下的就是求给环染色,相邻位置颜色不同的方案数。
第一部分可以用树形背包简单求出,设\(f_{u,i,0/1/2}\)表示\(u\)的子树内选了\(i\)条链,\(u\)的状态为:在一条进入了两个不同的子树的链上/只有一个单点/在一条至少有两个点的单链上,转移分类讨论一下就行了。(一条长度不小于\(2\)的链应该算两遍!!1)
现在我们已经求出了\(f_k\)表示选出\(k\)条链的带权方案数,直接EGF组合的话会出现环上相邻两条链属于同一棵树的情况。
考虑容斥,钦定最后环上属于该棵树的链构成\(j\)个极长连续段,那么\(f_k\)对该项的贡献为\((-1)^{k-j}{k-1\choose j-1}f_kk!\),所以EGF为

\[\widehat f(x)=\sum\limits_{i=1}^nf_ii!\sum\limits_{j=1}^i(-1)^{i-j}{i-1\choose j-1}\frac{x^j}{j!} \]

对于第一棵树而言,限制\(1\)必须是环的开头,同时首尾不能相同,推完式子可以发现EGF恰好是\(\frac{\widehat f(x)}x\)
最后把所有EGF乘起来就可以得到答案的EGF了。
注意这个做法在\(m=1\)的时候有问题的。

#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using i64=long long;
const int N=5007,P=998244353;
int size[N];i64 s[N][N],f[N][N][3];std::vector<int>e[N];
int read(){int x;scanf("%d",&x);return x;}
void inc(i64&a,i64 b){a+=b-P,a+=a>>63&P;}
void dec(i64&a,i64 b){a-=b,a+=a>>63&P;}
struct poly{int deg;i64 a[N];poly(int n){deg=n,memset(a,0,8*n+8);}i64&operator[](const int&x){return a[x];}};
poly operator*(poly f,poly g)
{
    poly a(f.deg+g.deg);
    for(int i=0;i<=f.deg;++i) for(int j=0;j<=g.deg;++j) inc(a[i+j],f[i]*g[j]%P);
    return a;
}
void dfs(int u,int fa)
{
    static i64 t[N][3];
    size[u]=1,memset(f[u],0,48),f[u][0][1]=1;
    for(int v:e[u])
    {
	if(v==fa) continue;
	dfs(v,u),memset(t,0,24*(size[u]+size[v]+1));
	for(int i=0;i<=size[u];++i)
	    for(int j=0;j<=size[v];++j)
	    {
		for(int k=0;k<3;++k) inc(t[i+j][k],f[u][i][k]*f[v][j][0]%P);
		inc(t[i+j][2],f[u][i][1]*f[v][j][1]%P),inc(t[i+j+1][0],2*f[u][i][2]*f[v][j][1]%P);
	    }
	size[u]+=size[v],memcpy(f[u],t,24*(size[u]+1));
    }
    for(int i=0;i<=size[u];++i) (f[u][i+1][0]+=2*f[u][i][2]+f[u][i][1])%=P,inc(f[u][i][1],f[u][i][2]);
}
poly solve()
{
    int n=read();poly a(n);
    for(int i=1;i<=n;++i) e[i].clear();
    for(int i=1,u,v;i<n;++i) u=read(),v=read(),e[u].push_back(v),e[v].push_back(u);
    dfs(1,0);
    for(int i=1;i<=n;++i) a[i]=f[1][i][0];
    for(int i=n;i;--i) for(int j=i+1;j<=n;++j) dec(a[i],a[j]*s[j][i]%P);
    return a;
}
int main()
{
    int m=read();poly ans(0);i64 sum=0,fac=1;ans[0]=s[1][1]=1;
    for(int i=2;i<=5000;++i) for(int j=1;j<=i;++j) s[i][j]=(s[i-1][j-1]+(i-1+j)*s[i-1][j])%P;
    for(int i=1;i<=m;++i) ans=ans*solve();
    for(int i=1;i<=ans.deg;++i) inc(sum,ans[i]*fac%P),fac=fac*i%P;
    printf("%lld",sum);
}
posted @ 2020-06-05 17:32  Shiina_Mashiro  阅读(244)  评论(0编辑  收藏  举报