[清华集训2017]生成树计数

[清华集训2017]生成树计数

题面

uoj

题解

考虑贡献 \(\mathrm{val}(T) = \left(\prod_{i=1}^{n} {d_i}^m\right)\left(\sum_{i=1}^{n} {d_i}^m\right)\),我们先不管后面 \(\sum_{i=1}^nd_i^m\) 的部分。

然后我们就搬来 prufer 序列联通块生成树理论的那一套,设生成树中每个联通块的度数为\(d_i\),那么贡献可以表示为

\[\sum_{\sum d_i=2n-2,d_i\geq 1}\frac {(n-2)!}{\prod(d_i-1)!}\prod a_i^{d_i}\prod d_i^m \]

其中\(\frac {(n-2)!}{\prod(d_i-1)!}\)是联通块构成的 prufer 序列数,\(\prod a_i^{d_i}\)是每个联通块选择点的方案数,\(\prod d_i^m\)是我们现在只考虑的贡献。
转化一下就是:

\[\prod a_i\sum_{\sum d_i=n-2,d_i\geq 0}\frac {(n-2)!}{\prod d_i!}\prod a_i^{d_i}\prod (d_i+1)^m \]

构造 EGF :

\[F(x)=\sum_{i=0}^{\infty} \frac {x^i}{i!}(i+1)^m \]

那么最后我们的答案就是

\[\prod a_i (n-2)![x^{n-2}]\prod F(a_ix) \]

现在问题就变为了如何求\(\prod F(a_ix)\)


考虑这样一个问题:给定一个\(m\)次多项式\(B(x)=\sum_{i=0}^m b_ix^i\)\(n\)个数\(a_i\),如何求\(\sum B(a_ix)\)

把和写开就是\(\sum B(a_ix)=\sum_i\sum_jb_ja_i^jx^j=\sum_jx^jb_j\sum_ia_i^j\),然后就是对于每个\(j\in[0,m]\)\(a\)的等幂和。

等幂和可以表示为\(\sum_i\frac 1{1-a_ix}\)的每一项,通分后就是\(\frac {\sum_i \prod_{j\neq i} (1-a_jx)}{\prod (1-a_ix)}\)

\(C(x)=\prod (1-a_ix)\),那么\(C\)的系数翻转之后的多项式\(C_R(x)=\prod (-a_i+x)\),求导后\(C_R'(x)=\sum_i1\times\prod _{j\neq i}(-a_j+x)\),最后\(\big (C'_R(x)\big )_R\)就是分子,用分治 FFT 和多项式求逆可以做到\(O(n\log ^2n+m\log m)\)


回到求\(\prod F(a_ix)\)\(\prod F(a_ix)=\exp(\sum\ln F(a_ix))=\exp (\sum(\ln F)(a_ix))\),然后就是上面求的等幂和了。

最后再考虑加上 \(\sum_id_i^m\) 的部分。

发现 \(\mathrm{val}(T) = \left(\prod_{i=1}^{n} {d_i}^m\right)\left(\sum_{i=1}^{n} {d_i}^m\right)\) 就是钦定某个\(i\)的贡献为\(d_i^{2m}\),令\(G(x)=\sum_{i=0}^{\infty} \frac {x^i}{i!}(i+1)^{2m}\),那么答案的生成函数可表示为

\[\left (\prod F(a_ix)\right )\left (\sum \frac {G(a_ix)}{F(a_ix)}\right) \]

求出\(H(x)=\frac {G(x)}{F(x)}\)后再做一遍等幂和即可。

最后复杂度是\(O(n\log ^2n+n\log m)\),复杂度与所给\(m\)基本无关,但是常数的话你懂的。。

代码

#include <bits/stdc++.h> 
using namespace std; 
int gi() { 
	int res = 0, w = 1; 
	char ch = getchar(); 
	while (ch != '-' && !isdigit(ch)) ch = getchar(); 
	if (ch == '-') w = -1, ch = getchar(); 
	while (isdigit(ch)) res = res * 10 + ch - '0', ch = getchar(); 
	return res * w; 
}
const int Mod = 998244353; 
int fpow(int x, int y) { 
	int res = 1;
	while (y) {
		if (y & 1) res = 1ll * res * x % Mod; 
		x = 1ll * x * x % Mod, y >>= 1; 
	}
	return res; 
} 
const int MAX_N = 2e5 + 5; 
int fac[MAX_N], ifc[MAX_N]; 
int Limit, rev[MAX_N], omg[MAX_N], inv[MAX_N];
#define VI vector<int> 
void FFT_prepare(int len) {
	int p = 0; 
	for (Limit = 1; Limit <= len; Limit <<= 1) p++; 
	for (int i = 1; i < Limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (p - 1)); 
	omg[0] = 1, omg[1] = fpow(3, (Mod - 1) / Limit); 
	for (int i = 2; i < Limit; i++) omg[i] = 1ll * omg[i - 1] * omg[1] % Mod; 
} 
void NTT(VI &p, int op) { 
	p.resize(Limit); 
	for (int i = 1; i < Limit; i++) if (i < rev[i]) swap(p[i], p[rev[i]]); 
	for (int i = 1, t = Limit >> 1; i < Limit; i <<= 1, t >>= 1)
		for (int j = 0; j < Limit; j += i << 1)
			for (int k = 0, o = 0; k < i; k++, o += t) { 
				int x = p[j + k], y = 1ll * omg[o] * p[i + j + k] % Mod; 
				p[j + k] = (x + y) % Mod, p[i + j + k] = (x - y + Mod) % Mod; 
			} 
	if (!op) { 
		reverse(p.begin() + 1, p.end()); 
		for (int i = 0; i < Limit; i++) p[i] = 1ll * p[i] * inv[Limit] % Mod; 
	} 
} 
void Poly_Der(int n, VI &a, VI &b) { 
	b.resize(n - 1); 
	for (int i = 0; i < n - 1; i++) b[i] = 1ll * a[i + 1] * (i + 1) % Mod; 
} 
void Poly_Int(int n, VI &a, VI &b) { 
	b.resize(n + 1), b[0] = 0; 
	for (int i = 1; i <= n; i++) b[i] = 1ll * a[i - 1] * inv[i] % Mod; 
} 
void Poly_Inv(int n, VI a, VI &b) { 
	if (n == 1) return b.clear(), b.push_back(fpow(a[0], Mod - 2)); 
	Poly_Inv((n + 1) >> 1, a, b);
	a.resize(n), b.resize(n);
	VI d = b; 
	FFT_prepare(1.5 * n + 0.5); 
	NTT(a, 1), NTT(d, 1); 
	for (int i = 0; i < Limit; i++) a[i] = 1ll * a[i] * d[i] % Mod * d[i] % Mod; 
	NTT(a, 0); 
	for (int i = (n + 1) >> 1; i < n; i++) b[i] = Mod - a[i];
	d.clear();
} 

void Poly_Ln(int n, VI a, VI &b) { 
	VI c, d; Poly_Inv(n, a, c), Poly_Der(n, a, d); 
	FFT_prepare(n + n); 
	NTT(c, 1), NTT(d, 1); 
	for (int i = 0; i < Limit; i++) c[i] = 1ll * c[i] * d[i] % Mod; 
	NTT(c, 0); 
	Poly_Int(n, c, b); 
} 
void Poly_Exp(int n, VI a, VI &b) { 
	if (n == 1) return b.clear(), b.push_back(1); 
	Poly_Exp((n + 1) >> 1, a, b), b.resize(n); 
	VI c, d = b; Poly_Ln(n, b, c); 
	FFT_prepare(n + 1); 
	for (int i = 0; i < n; i++) c[i] = (a[i] - c[i] + (i == 0) + Mod) % Mod; 
	NTT(c, 1), NTT(d, 1); 
	for (int i = 0; i < Limit; i++) c[i] = 1ll * c[i] * d[i] % Mod; 
	NTT(c, 0); 
	for (int i = (n + 1) >> 1; i < n; i++) b[i] = c[i]; 
} 
int N = 2e5, M; 
int a[MAX_N];
VI Div(int l, int r) { 
	if (l == r) return {1, Mod - a[l]}; 
	int mid = (l + r) >> 1; 
	VI L = Div(l, mid), R = Div(mid + 1, r); 
	int len = L.size() + R.size() - 1; 
	FFT_prepare(len); 
	NTT(L, 1), NTT(R, 1); 
	for (int i = 0; i < Limit; i++) L[i] = 1ll * L[i] * R[i] % Mod; 
	NTT(L, 0), L.resize(len); 
	return L; 
} 
VI F, G, H, iF, A, B, C, CR, dCR, iC, LnF, pF, ans; 

int main () { 
#ifndef ONLINE_JUDGE 
    freopen("cpp.in", "r", stdin);
#endif 
	for (int i = fac[0] = 1; i <= N; i++) fac[i] = 1ll * fac[i - 1] * i % Mod; 
	ifc[N] = fpow(fac[N], Mod - 2); 
	for (int i = N - 1; ~i; i--) ifc[i] = 1ll * ifc[i + 1] * (i + 1) % Mod; 
	for (int i = inv[0] = 1; i <= N; i++) inv[i] = 1ll * ifc[i] * fac[i - 1] % Mod; 
	N = gi(), M = gi();
	if (N == 1) return puts(M == 0 ? "1" : "0") & 0; 
	for (int i = 1; i <= N; i++) a[i] = gi(); 
	//prepare F, G
	F.resize(N), G.resize(N); 
	for (int i = 0; i < N; i++) F[i] = 1ll * ifc[i] * fpow(i + 1, M) % Mod; 
	for (int i = 0; i < N; i++) G[i] = 1ll * ifc[i] * fpow(i + 1, 2 * M) % Mod;
	//prepare iF
	Poly_Inv(N, F, iF);
	//prepare H
	FFT_prepare(N << 1); 
	A = iF, B = G; 
	NTT(A, 1), NTT(B, 1), H.resize(Limit); 
	for (int i = 0; i < Limit; i++) H[i] = 1ll * A[i] * B[i] % Mod; 
	NTT(H, 0);
	//prepare C = sigma 1 / (1 - a[i]x)
	C = Div(1, N); CR = C; reverse(CR.begin(), CR.end()); 
	Poly_Der(CR.size(), CR, dCR); 
	reverse(dCR.begin(), dCR.end()); 
	Poly_Inv(C.size(), C, iC); 
	FFT_prepare(iC.size() + dCR.size()); 
	NTT(iC, 1), NTT(dCR, 1), C.resize(Limit); 
	for (int i = 0; i < Limit; i++) C[i] = 1ll * iC[i] * dCR[i] % Mod; 
	NTT(C, 0); 
	//prepare prod F(a[i]x)
	Poly_Ln(N, F, LnF);
	for (int i = 0; i < N; i++) LnF[i] = 1ll * LnF[i] * C[i] % Mod; 
	Poly_Exp(N, LnF, pF); 
	//prepare sigma H(a[i]x)
	for (int i = 0; i < N; i++) H[i] = 1ll * H[i] * C[i] % Mod; 
	//getans
	FFT_prepare(N << 1); 
	NTT(pF, 1), NTT(H, 1), ans.resize(Limit); 
	for (int i = 0; i < Limit; i++) ans[i] = 1ll * pF[i] * H[i] % Mod; 
	NTT(ans, 0); 
	int pa = 1;
	for (int i = 1; i <= N; i++) pa = 1ll * pa * a[i] % Mod; 
	printf("%lld\n", 1ll * ans[N - 2] * fac[N - 2] % Mod * pa % Mod); 
    return 0; 
} 
posted @ 2020-10-27 17:32  heyujun  阅读(234)  评论(0编辑  收藏  举报