[CF438E]The Child and Binary Tree

  传送门

题目大意

  给定一个大小为\(n\)的集合\(S\),\(\forall x \in S\),\(x\)为正整数且\(x\leq 10^5\)。称一棵有根二叉树(每个结点的左、右儿子均可为空)是合法的,当且仅当它每个结点的权值都属于\(S\)。给定\(m\),对于每个正整数\(s\leq m\),求所有结点权值之和恰为\(s\)的不同合法二叉树的个数。两棵树不同当且仅当树的形态不同,或者存在对应结点的权值不同。答案对\(998244353\)取模。

  具体样例看原题。数据范围:\(n,m\leq 10^5\)

题解

  \([state]\)为艾弗森括号(Iverson bracket):当命题\(state\)为真时值为\(1\),否则为\(0\)。

  设\(f_i\)表示所有结点权值之和为\(i\)的合法二叉树个数(空树记为\(f_0=1\))。枚举根结点的权值\(i\in S\),再枚举左子树的权值和\(j\)(右子树的权值和即为\(n-i-j\)),很容易列出\(\text{DP}\)方程:

\[f_n=\begin{cases}\begin{aligned}\sum_{i=0}^n[i\in S]\sum_{j=0}^{n-i}f_jf_{n-i-j}\end{aligned}&n>0\\1&n=0\end{cases} \]

  朴素转移的复杂度为\(O(m^2)\),在本题的数据范围下无法接受,于是考虑生成函数。

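  定义生成函数\(G(x)=\sum\limits_{i\geq 0}[i\in S]x^i\),以及生成函数\(F(x)=\sum\limits_{i\geq 0}f_ix^i\)。由于\(0\notin S\),DP式中的求和部分在\(n=0\)时恒为\(0\),因此对所有\(n\geq 0\)都有\(f_n=[n=0]+\sum_{i=0}^n[i\in S]\sum_{j=0}^{n-i}f_jf_{n-i-j}\),于是得到下面的式子: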

\[\begin{aligned} F(x)&=\sum_{n\geq 0}f_nx^n=\sum_{n\geq 0}x^n\left([n=0]+\sum_{i=0}^n[i\in S]\sum_{j=0}^{n-i}f_jf_{n-i-j}\right)\\ &=1+\sum_{n\geq 0}\sum_{i=0}^n[i\in S]x^i\sum_{j=0}^{n-i}f_jx^j\cdot f_{n-i-j}x^{n-i-j}\\ &=1+G(x)F^2(x) \end{aligned} \]

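  把\(F=1+GF^2\)看作关于\(F\)的二次方程\(GF^2-F+1=0\),形式上解得\(F=\dfrac{1\pm\sqrt{1-4G}}{2G}\),分母有理化后即\(F=\dfrac{2}{1\mp \sqrt{1-4G}}\)。由于\(g_0=0\),\(\sqrt{1-4G}\)的常数项为\(1\):若取\(F=\dfrac{2}{1-\sqrt{1-4G}}\),分母的常数项为\(0\),无法求逆;而\(F=\dfrac{2}{1+\sqrt{1-4G}}\)的常数项为\(1\),与\(f_0=1\)吻合。因此\(F=\dfrac{2}{1+\sqrt{1-4G}}\),依次做多项式开根、多项式求逆即可。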

  复杂度:\(O(n+m\log m)\)(读入为\(O(n)\),多项式开根与求逆均为\(O(m\log m)\))。

代码

#include <bits/stdc++.h>

using namespace std;

// Base size for the static tables and scratch buffers (problem bound 10^5 plus slack).
const int maxn = 100005;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1, and its primitive root 3.
const int P = 998244353;
const int g = 3;

// Modular addition; both operands are expected in [0, P] and the result lands in [0, P).
int inc(int a, int b) {
	const int sum = a + b;
	if (sum >= P) return sum - P;
	return sum;
}

// a^b mod P by binary exponentiation (square-and-multiply), consuming b one bit at a time.
int qpow(int a, int b) {
	long long acc = 1, base = a;
	while (b != 0) {
		if (b & 1) acc = acc * base % P;
		base = base * base % P;
		b >>= 1;
	}
	return static_cast<int>(acc);
}

// Global lookup tables shared by the NTT and polynomial routines:
//   W[h + k] (h a power of two, 0 <= k < h) = w_{2h}^k, where w_{2h} = g^((P-1)/(2h));
//   inv[i], fac[i], ifac[i] = i^{-1}, i!, (i!)^{-1} modulo P.
int W[maxn << 3], inv[maxn << 2], fac[maxn << 2], ifac[maxn << 2];

// Fill W for every power of two h < n, and inv/fac/ifac for every index below n
// (inv[1], fac[0] and ifac[0] are always set). The caller must keep n within the
// table sizes declared above.
void prework(int n) {
	for (int h = 1; h < n; h <<= 1) {
		const int root = qpow(g, (P - 1) / h >> 1);
		long long w = 1;
		for (int k = 0; k < h; k++, w = w * root % P) W[h + k] = static_cast<int>(w);
	}
	fac[0] = ifac[0] = 1;
	inv[1] = 1;
	for (int i = 1; i < n; i++) {
		// Linear-time inverse recurrence; P % i < i has already been filled.
		if (i > 1) inv[i] = 1ll * (P - P / i) * inv[P % i] % P;
		fac[i] = 1ll * fac[i - 1] * i % P;
		ifac[i] = 1ll * ifac[i - 1] * inv[i] % P;
	}
}

// In-place iterative radix-2 NTT over Z_P on a[0..n-1].
// n must be a power of two (the bit-reversal below relies on it) and small enough for
// the rev table and for the W/inv tables filled by prework.
// opt != -1 : forward transform using the twiddles stored in W.
// opt == -1 : inverse transform - the same butterflies, then scale by n^{-1} and
//             reverse a[1..n-1], which turns evaluation at w^k into evaluation at w^{-k}.
void ntt(int *a, int n, int opt) {
	// Bit-reversal permutation table; entries 1..n-1 are recomputed on every call, rev[0] stays 0.
	static int rev[maxn << 2] = {0};
	for (int i = 1; i < n; i++)
		if ((rev[i] = rev[i>>1]>>1|(i&1?n>>1:0)) > i) swap(a[i], a[rev[i]]);
	// Butterflies: q is the half-length of the current block and W[q+i] = w_{2q}^i.
	for (int q = 1; q < n; q <<= 1)
		for (int p = 0; p < n; p += q << 1)
			for (int i = 0, t; i < q; i++)
				t = 1ll*W[q+i]*a[p+q+i]%P, a[p+q+i] = inc(a[p+i], P-t), a[p+i] = inc(a[p+i], t);
	if (~opt) return;  // ~opt is zero only for opt == -1, so only the inverse transform continues
	for (int i = 0; i < n; i++) a[i] = 1ll*a[i]*inv[n]%P;
	reverse(a+1, a+n);
}

// Dense polynomial over Z_P: A[i] is the coefficient of x^i, for 0 <= i < len.
// Invariant: len == A.size() and len >= 1 (the zero polynomial is stored as {0}).
// The converting constructor is deliberately implicit so integers can appear directly
// in polynomial expressions such as `1 - 4 * G`.
struct poly {
	std::vector<int> A;
	int len;
	poly(int a0 = 0) : A(1, a0), len(1) {}
	int &operator [] (int i) { return A[i]; }
	// Read-only coefficient access for const polynomials.
	int operator [] (int i) const { return A[i]; }
	// Print all len coefficients separated by spaces, followed by a newline.
	void write() const {
		for (int i = 0; i < len; i++) printf("%d ", A[i]);
		putchar('\n');
	}
	// Replace the contents with from[0..n-1]. A non-positive n yields the zero
	// polynomial {0}. Using memcpy on &A[0] of an empty vector would be UB, so copy instead.
	void load(const int *from, int n) {
		if (n <= 0) {
			A.assign(1, 0);
			len = 1;
			return;
		}
		A.assign(from, from + n);
		len = n;
	}
	// Copy the first min(n, len) coefficients into to[]; nothing is copied when that count is <= 0.
	void cpyto(int *to, int n) const {
		const int cnt = std::min(n, len);
		if (cnt > 0) std::copy(A.begin(), A.begin() + cnt, to);
	}
	// resize(n), n != 0 : truncate or zero-pad to exactly n coefficients.
	// resize() / resize(0): drop trailing zero coefficients, always keeping at least one.
	void resize(int n = 0) {
		if (n) {
			A.resize(len = n, 0);
			return;
		}
		while (len > 1 && !A[len - 1]) len--;
		A.resize(len);
	}
} F, G;

// Multiplicative inverse of A modulo x^{A.len}, by Newton iteration B <- 2B - A*B^2.
// Each round doubles the number of correct coefficients.
// Precondition: A[0] != 0 (mod P); otherwise the seed qpow(A[0], P-2) is 0 and the result is meaningless.
// Uses static scratch buffers of maxn << 2 ints, so 4 * (padded length) must fit in them.
poly poly_inv(poly A) {
	static int fa[maxn << 2], fb[maxn << 2];
	poly B = poly(qpow(A[0], P - 2));
	int cur = 1;
	while (cur < A.len) {
		const int sz = cur << 2;
		std::fill_n(fa, sz, 0);
		std::fill_n(fb, sz, 0);
		A.cpyto(fa, cur << 1);  // A mod x^{2cur}
		B.cpyto(fb, cur);       // current inverse, correct mod x^{cur}
		ntt(fa, sz, 1);
		ntt(fb, sz, 1);
		for (int k = 0; k < sz; ++k) {
			const int abb = 1ll * fa[k] * fb[k] % P * fb[k] % P;
			fa[k] = inc(inc(fb[k], fb[k]), P - abb);
		}
		ntt(fa, sz, -1);
		B.load(fa, cur << 1);
		cur <<= 1;
	}
	B.resize(A.len);
	return B;
}

// Square root of A modulo x^{A.len}, by Newton iteration B <- (B + A / B) / 2.
// Each round doubles the number of correct coefficients.
// Precondition: A[0] == 1, because the iteration is seeded with B = 1.
// Uses static scratch buffers of maxn << 2 ints, so 4 * (padded length) must fit in them.
poly poly_sqrt(poly A) {
	static int fa[maxn << 2], fb[maxn << 2], fc[maxn << 2];
	const int half = (P + 1) >> 1;  // inverse of 2 modulo P
	poly B(1);
	for (int cur = 1; cur < A.len; cur <<= 1) {
		const int sz = cur << 2;
		std::fill_n(fa, sz, 0);
		std::fill_n(fb, sz, 0);
		std::fill_n(fc, sz, 0);
		A.cpyto(fa, cur << 1);            // A mod x^{2cur}
		B.cpyto(fb, cur);                 // current root, correct mod x^{cur}
		B.resize(cur << 1);
		poly_inv(B).cpyto(fc, cur << 1);  // B^{-1} mod x^{2cur}
		ntt(fa, sz, 1);
		ntt(fb, sz, 1);
		ntt(fc, sz, 1);
		for (int k = 0; k < sz; ++k)
			fa[k] = (fa[k] + 1ll * fb[k] * fb[k]) % P * fc[k] % P * half % P;
		ntt(fa, sz, -1);
		B.load(fa, cur << 1);
	}
	B.resize(A.len);
	return B;
}

// Coefficient-wise sum A + B (mod P); trailing zero coefficients of the result are trimmed.
poly operator + (poly A, poly B) {
	A.resize(std::max(A.len, B.len));  // zero-pad A so every coefficient of B has a partner
	for (int k = 0; k < B.len; ++k) A[k] = inc(A[k], B[k]);
	A.resize();
	return A;
}

// Coefficient-wise difference A - B (mod P); trailing zero coefficients of the result are trimmed.
poly operator - (poly A, poly B) {
	A.resize(std::max(A.len, B.len));  // zero-pad A so every coefficient of B has a partner
	for (int k = 0; k < B.len; ++k) A[k] = inc(A[k], P - B[k]);
	A.resize();
	return A;
}

// Smallest power of two that is >= n (1 for any n <= 1).
// Uses pow2ceil(n) = 2 * pow2ceil(ceil(n / 2)) for n > 1.
int getsize(int n) {
	return n <= 1 ? 1 : 2 * getsize((n + 1) / 2);
}

// Scalar multiple k * A (mod P). k may be any int, including negative values or values >= P.
// It is reduced into [0, P) first, so the coefficients stay in [0, P) as inc() and the NTT expect.
// Previously, a negative k produced negative coefficients.
// Unlike + and -, trailing zero coefficients are not trimmed.
poly operator * (int k, poly A) {
	int c = k % P;
	if (c < 0) c += P;
	for (int i = 0; i < A.len; i++) A[i] = 1ll * c * A[i] % P;
	return A;
}

// Polynomial product A * B via NTT. The raw result has A.len + B.len - 1 coefficients,
// and trailing zeros are then trimmed.
// Uses static scratch buffers of maxn << 2 ints, so the padded transform size must fit in them.
poly operator * (poly A, poly B) {
	static int fa[maxn << 2], fb[maxn << 2];
	const int want = A.len + B.len - 1;
	const int sz = getsize(want);
	std::fill_n(fa, sz, 0);
	std::fill_n(fb, sz, 0);
	A.cpyto(fa, A.len);
	B.cpyto(fb, B.len);
	ntt(fa, sz, 1);
	ntt(fb, sz, 1);
	for (int k = 0; k < sz; ++k) fa[k] = 1ll * fa[k] * fb[k] % P;
	ntt(fa, sz, -1);
	A.load(fa, want);
	A.resize();
	return A;
}

// Formal derivative A'(x); trailing zero coefficients of the result are trimmed.
poly poly_deri(poly A) {
	const int last = A.len - 1;
	for (int k = 1; k <= last; ++k) A[k - 1] = 1ll * A[k] * k % P;
	A[last] = 0;
	A.resize();
	return A;
}

// Formal integral with zero constant term, kept at the same length as A.
// The true integral's degree-A.len coefficient is dropped. Reads inv[1..A.len-1] from prework.
poly poly_int(poly A) {
	int k = A.len;
	while (--k > 0) A[k] = 1ll * inv[k] * A[k - 1] % P;
	A[0] = 0;
	return A;
}

// Logarithm ln A modulo x^{A.len}, computed as the integral of A' / A.
// The result always has constant term 0, so this is ln A only when A[0] == 1.
poly poly_ln(poly A) {
	const int n = A.len;
	poly quotient = poly_deri(A) * poly_inv(A);
	quotient.resize(n);
	return poly_int(quotient);
}

// Exponential exp(A) modulo x^{A.len}, by Newton iteration B <- B * (1 - ln B + A).
// Seeded with B = 1, so A[0] is expected to be 0.
poly poly_exp(poly A) {
	poly B(1);
	for (int half = 1; half < A.len; half *= 2) {
		const int sz = half * 2;
		B.resize(sz);
		poly step = poly(1) - poly_ln(B);
		step = step + A;
		step.resize(sz);
		B = B * step;
	}
	B.resize(A.len);
	return B;
}

int n, m;

// Reads n, m and the n node weights, then prints f_1 .. f_m, one per line, where
// F(x) = sum_s f_s x^s = 2 / (1 + sqrt(1 - 4 G(x))) and G(x) = sum_{c in S} x^c.
int main() {
	scanf("%d%d", &n, &m); m++;  // from here on m is the polynomial length (x^0 .. x^m)
	prework(m << 2);             // the NTT lengths used below stay below 4 * m

	G.resize(m);
	for (int i = 1; i <= n; i++) {
		int c = 0;  // stays 0 (and is then ignored) if the read fails
		scanf("%d", &c);
		// Weights are positive by the statement. Ignoring c <= 0 does two things:
		// - it keeps g_0 = 0, so 1 - 4G has constant term 1 as poly_sqrt requires;
		// - it avoids an out-of-range write on malformed input.
		if (c >= 1 && c < m) G[c] = 1;
	}

	G = 1-4*G; G.resize(m);             // 1 - 4G
	G = 1+poly_sqrt(G); G.resize(m);    // 1 + sqrt(1 - 4G)
	F = 2*poly_inv(G); F.resize(m);     // F = 2 / (1 + sqrt(1 - 4G))
	for (int i = 1; i < m; i++) printf("%d\n", F[i]);

	return 0;
}
posted @ 2019-12-24 14:44  AC-Evil  阅读(158)  评论(0编辑  收藏  举报