[清华集训2017]生成树计数
在一个\(s\)个点的图中,存在\(s-n\)条边,使图中形成了\(n\)个连通块,第\(i\)个连通块中有\(a_i\)个点。
现在我们需要再连接\(n-1\)条边,使该图变成一棵树。对一种连边方案,设原图中第\(i\)个连通块连出了\(d_i\)条边,那么这棵树\(T\)的价值为:
你的任务是求出所有可能的生成树的价值之和,对\(998244353\)取模。
(可能只有我没读出来题目说连通块内的连边方式不计。)
树和每个点的度数可以联想到\(prufer\)序列。那么设\(c_i\)为第\(i\)个点在\(prufer\)序列中出现的次数,则\(c_i=d_i-1\)。考虑对于一个确定的序列\(c_i\),它对答案的贡献就是
第一项是这个序列对应的有标号无根树个树;第二个是因为每个连通块中的点可以任意分配这个连通块的出边;后面两个是题面定义的价值。
那么有一个暴力思路就是递推,在那之前我们把\((n-2)!\prod_{i=1}^na_i\)看作常数项,只考虑剩余的式子。设
即递推前\(n\)个点的总\(c_i\)为\(m\)的所有情况之和。转移就有:
答案就是\(g_{n,n-2}\times \frac{(n-2)!}{\prod_{i=1}^na_i}\)。
现在就可以有\(20\)分的好成绩了。如果用\(NTT\)实现上面的转移就可以有\(40\)分的好成绩。
然后发现我这个式子并不好优化(懒得优化两个式子)。瞟一眼题解之后发现开始那个式子可以化得好看些:
我们还是不管前面的常数项。可以发现每个\(c_i\)对式子的贡献就是\(\frac{a_i^{c_i}(c_i+1)^m}{c_i!}\)或者一个序列中仅有一个\(c_i\)贡献为\(\frac{a_i^{c_i}(c_i+1)^{2m}}{c_i!}\),那么这个式子就可以由若干个次数表示\(c\)的\(EGF\)乘起来(实际上如果尝试用\(EGF\)推一下那个\(n^3\)递推可以更容易发现这种性质)。乍一看一共有\(n\)个系数不同的\(n\)次多项式,似乎不可做,但是第\(i\)个多项式每一项都有\(a_i\)的若干次方,且与\(x\)次数相同,所以这些多项式都可以写成\(F(a_ix)\)的形式。因此设
答案就是\(\sum_{i=1}^n\frac{A(a_ix)}{B(a_ix)}\prod_{j=1}^nB(a_jx)\)。这样有什么好处呢?这里补充一下这种trick。
如果式子可以写成\(\sum_{i=1}^nF(a_ix)\)的形式,并且对任意\(m\)都求出了\(\sum_{i=1}^na_i^m\),那么只要求出\(F(x)\),式子就可以变成
因此我们求出\(\frac{A(x)}{B(x)}\)和\(\prod_{i=1}^nB(a_ix)\)即可。但后面这个是\(\prod\),和前面的\(\sum\)不同,这里就要取个\(\ln\):
那么我们算一下\(\ln{B(x)}\)就可以像上面那样算了。
现在唯一的问题就是怎么对每个\(m\)求出\(\sum_{i=1}^na_i^m\)。类似于自然数幂和的推导,我们写出这个东西的\(OGF\)就有
这就有点像P4705玩游戏这题的技巧,因为
那设\(H(x)=\sum_{i=1}^n(\ln(1-a_ix))'\),那\(G(x)=-xH(x)+n\)。求\(H\)就:
分治\(NTT\)即可。至此这题就解决了,复杂度瓶颈为最后分治\(NTT\)的\(\mathcal{O}(n\log^2n)\)。
#include<bits/stdc++.h>
#define rg register
#define il inline
#define cn const
#define gc getchar()
#define fp(i, a, b) for(int i = (a), ed = (b); i <= ed; ++i)
#define fb(i, a, b) for(int i = (a), ed = (b); i >= ed; --i)
#define go(u) for(int i = head[u]; ~i; i = e[i].nxt)
using namespace std;
typedef cn int cint;
typedef long long LL;
il void rd(int &x){
x = 0;
rg int f(1); rg char c(gc);
while(c < '0' || '9' < c){if(c == '-')f = -1; c = gc;}
while('0' <= c && c <= '9')x = (x<<1)+(x<<3)+(c^48), c = gc;
x *= f;
}
cint maxn = 30010, mod = 998244353, G = 3, invG = (mod+1)/3;
int n, m, a[maxn], fac[maxn], ifac[maxn], inv[maxn], mul = 1;
int lim, hst, rev[maxn<<2], A[maxn<<2], B[maxn<<2], ln[maxn<<2], invB[maxn<<2];
int H[maxn<<2], E[maxn<<2];
il int fpow(int a, int b, int ans = 1){
for(; b; b >>= 1, a = 1ll*a*a%mod)if(b&1)ans = 1ll*ans*a%mod;
return ans;
}
il void ntt(int *a, cint &typ){
fp(i, 0, lim-1)if(i > rev[i])swap(a[i], a[rev[i]]);
for(rg int md = 1; md < lim; md <<= 1){
rg int len = md<<1, Gn = fpow(typ ? invG : G, (mod-1)/len);
for(rg int l = 0; l < lim; l += len){
for(rg int nw = 0, Pow = 1; nw < md; ++nw, Pow = 1ll*Pow*Gn%mod){
rg int x = a[l+nw], y = 1ll*a[l+nw+md]*Pow%mod;
a[l+nw] = (x+y)%mod, a[l+nw+md] = (x-y+mod)%mod;
}
}
}
if(typ){
rg int inv = fpow(lim, mod-2);
fp(i, 0, lim-1)a[i] = 1ll*a[i]*inv%mod;
}
}
il void init(int n){
lim = 1, hst = 0;
while(lim < n)lim <<= 1, ++hst;
fp(i, 0, lim-1)rev[i] = (rev[i>>1]>>1)|((i&1)<<hst-1);
}
int inv_ary[maxn<<2];
void get_inv(int *a, int *f, int n){
if(n == 1)return f[0] = fpow(a[0], mod-2), void();
get_inv(a, f, n+1>>1), init(2*n-1);
fp(i, 0, n-1)inv_ary[i] = a[i];
ntt(f, 0), ntt(inv_ary, 0);
fp(i, 0, lim-1)f[i] = 1ll*f[i]*(2-1ll*f[i]*inv_ary[i]%mod+mod)%mod;
ntt(f, 1);
fp(i, n, lim-1)f[i] = 0;
fp(i, 0, lim-1)inv_ary[i] = 0;
}
int ln_ary[maxn<<2];
il void get_ln(int *a, int *f, int n){
get_inv(a, ln_ary, n), init(2*n-2);
fp(i, 1, n-1)f[i-1] = 1ll*a[i]*i%mod;
ntt(f, 0), ntt(ln_ary, 0);
fp(i, 0, lim-1)f[i] = 1ll*f[i]*ln_ary[i]%mod;
ntt(f, 1);
fp(i, n-1, lim)f[i] = 0;
fp(i, 0, lim)ln_ary[i] = 0;
fb(i, n-1, 1)f[i] = 1ll*f[i-1]*inv[i]%mod;
f[0] = 0;
}
int exp_ary[maxn<<2];
void get_exp(int *a, int *f, int n){
if(n == 1)return f[0] = 1, void();
get_exp(a, f, n+1>>1), get_ln(f, exp_ary, n), init(2*n-1);
fp(i, 0, n-1)exp_ary[i] = (a[i]-exp_ary[i]+mod)%mod;
if((++exp_ary[0]) == mod)exp_ary[0] = 0;
ntt(f, 0), ntt(exp_ary, 0);
fp(i, 0, lim-1)f[i] = 1ll*f[i]*exp_ary[i]%mod;
ntt(f, 1);
fp(i, n, lim-1)f[i] = 0;
fp(i, 0, lim-1)exp_ary[i] = 0;
}
int div_ary[19][maxn<<2];
void divntt(int d, int l, int r){
if(l == r)return div_ary[d][0] = 1, div_ary[d][1] = mod-a[l], void();
int md = l+r>>1, len = r-l+2;
divntt(d, l, md), divntt(d+1, md+1, r), init(len), ntt(div_ary[d], 0), ntt(div_ary[d+1], 0);
fp(i, 0, lim-1)div_ary[d][i] = 1ll*div_ary[d][i]*div_ary[d+1][i]%mod;
ntt(div_ary[d], 1);
fp(i, len, lim-1)div_ary[d][i] = 0;
fp(i, 0, lim-1)div_ary[d+1][i] = 0;
}
int main(){
// freopen("in", "r", stdin);
rd(n), rd(m);
fp(i, 1, n)rd(a[i]), mul = 1ll*mul*a[i]%mod;
fac[0] = 1; fp(i, 1, n)fac[i] = 1ll*fac[i-1]*i%mod;
ifac[n] = fpow(fac[n], mod-2); fb(i, n, 1)ifac[i-1] = 1ll*ifac[i]*i%mod;
inv[1] = 1; fp(i, 2, n)inv[i] = mod-1ll*(mod/i)*inv[mod%i]%mod;
fp(i, 0, n)A[i] = 1ll*fpow(i+1, m<<1)*ifac[i]%mod;
fp(i, 0, n)B[i] = 1ll*fpow(i+1, m)*ifac[i]%mod;
get_ln(B, ln, n+1), get_inv(B, invB, n+1), init(2*n+1);
// fp(i, 0, n)printf("%d ", B[i]);puts("");
// fp(i, 0, n)printf("%d ", ln[i]);puts("");
// fp(i, 0, n)printf("%d ", invB[i]);puts("");
// fp(i, 0, n)printf("%d ", A[i]);puts("");
// fp(i, 0, n)printf("%d ", invB[i]);puts("");
ntt(invB, 0), ntt(A, 0);
fp(i, 0, lim-1)A[i] = 1ll*A[i]*invB[i]%mod;
ntt(A, 1);
fp(i, n+1, lim)A[i] = 0;
// fp(i, 0, n)printf("%d ", A[i]);puts("");
divntt(0, 1, n), get_ln(div_ary[0], H, n+1);
// fp(i, 0, n)printf("%d ", div_ary[0][i]);puts("");
fp(i, 1, n)H[i] = mod-1ll*H[i]*i%mod;
H[0] = n;
fp(i, 0, n)ln[i] = 1ll*ln[i]*H[i]%mod;
get_exp(ln, E, n+1);
fp(i, 0, n)A[i] = 1ll*A[i]*H[i]%mod;
init(2*n+1), ntt(E, 0), ntt(A, 0);
fp(i, 0, lim-1)A[i] = 1ll*A[i]*E[i]%mod;
ntt(A, 1), printf("%lld\n", 1ll*fac[n-2]*mul%mod*A[n-2]%mod);
return 0;
}