CF1580F Problems for Codeforces 【生成函数,组合计数】

给定正整数 \(n,m\),求有多少个正整数序列 \(a_1,\cdots,a_n\) 使得 \(a_i+a_{i+1}<m\)\(a_1+a_n<m\),答案对 \(998\,244\,353\) 取模。

\(n\le 5\cdot 10^4\)\(m\le 10^9\)


先看 \(n\) 是偶数的情况:当 \(i\) 为奇数时把 \(a_i\) 改为 \(m-1-a_i\),条件变为 \(a_1\le a_2\ge\cdots\le a_n\ge a_1\),对 \(\ge\) 进行一个容斥,计算选择一些 \(\ge\) 变为 \(<\),其他忽略掉的方案数。长为 \(2l\) 的段的填数方案数为 \(\dbinom{m+l}{2l}\),设生成函数 \(\displaystyle F(z)=\sum_{l\ge 1}\binom{m+l}{2l}z^l\),过程就是按顺序确定好段长,然后确定 \(a_1\) 在段上的位置,从而答案即为 \((-1)^{n/2-1}[z^{n/2-1}]\dfrac{F'(z)}{1+F(z)}\)

然后就 \(n\) 是奇数的情况:设 \(m_0=\lfloor m/2\rfloor\)\(m_1=\lceil m/2\rceil\),考虑 \([a_i\ge m_1]\),不能有相邻两个都为 \(1\),则整个环被划分为了若干段,每段形如 \(010\cdots 010\)。将不小于 \(m_1\)\(a_i\) 减去 \(m_1\),此时 \(0\le a_i\le m_0\),若存在 \(a_i=m_0\),则 \(m\) 为奇数且 \(a_i\) 的相邻两项都小于 \(m_1\),从而只会在长为 \(1\) 的段上出现,特判一下。

现在要求计算 \(a_i\in[0,m_0)\) 使得 \(a_0\le a_1\ge\cdots\le a_{2l-1}\ge a_{2l}\) 的方案数,仍然对 \(\ge\) 容斥,设 \(\displaystyle G(z)=\sum_{l\ge 0}\binom{m_0+l}{2l+1}\),则对应方案数的生成函数是 \(P(z)=\dfrac{G(-z)}{1+F(-z)}+[2\nmid m]z\),同样拼一拼,设 \(P(z)\) 逐位乘上 \(2i+1\) 得到 \(Q(z)\),则答案为 \([z^{(n-1)/2}]\dfrac{Q(z)}{1-zP^2(z)}\)

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 16, mod = 998244353;
int ksm(int a, int b){
	int res = 1;
	for(;b;b >>= 1, a = (LL)a * a % mod)
		if(b & 1) res = (LL)res * a % mod;
	return res;
}
int n, m, w[N], rev[N], lim, A[N], B[N], C[N], iv[N], res;
void calrev(int len){
	int L = -1; lim = 1;
	while(lim <= len){lim <<= 1; ++ L;}
	for(int i = 1;i < lim;++ i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
}
void NTT(int *A, bool op){
	for(int i = 0;i < lim;++ i)
		if(i < rev[i]) swap(A[i], A[rev[i]]);
	for(int md = 1;md < lim;md <<= 1)
		for(int i = 0;i < lim;i += md << 1)
			for(int j = 0;j < md;++ j){
				int y = LL(op && j ? mod - w[(md << 1) - j] : w[md + j]) * A[md + i + j] % mod;
				if((A[md + i + j] = A[i + j] - y) < 0) A[md + i + j] += mod;
				if((A[i + j] += y) >= mod) A[i + j] -= mod;
			}
	if(op){
		int inv = ksm(lim, mod - 2);
		for(int i = 0;i < lim;++ i) A[i] = (LL)A[i] * inv % mod;
	}
}
int ans[N], tmp[N];
void getinv(int *A, int d){
	if(d == 1){ans[0] = ksm(A[0], mod - 2); return;}
	getinv(A, (d + 1) >> 1); calrev(d << 1);
	memcpy(tmp, A, d << 2); memset(tmp + d, 0, (lim - d) << 2);
	NTT(tmp, 0); NTT(ans, 0);
	for(int i = 0;i < lim;++ i)
		ans[i] = (mod + 2ll - (LL)ans[i] * tmp[i] % mod) * ans[i] % mod;
	NTT(ans, 1); memset(ans + d, 0, (lim - d) << 2);
}
int main(){
	ios::sync_with_stdio(0);
	cin >> n >> m; iv[1] = 1;
	for(int i = 2;i < N;++ i) iv[i] = mod - (LL)mod / i * iv[mod % i] % mod;
	for(int md = 1;md < N;md <<= 1){
		int Wn = ksm(3, (mod - 1) / (md << 1)); w[md] = 1;
		for(int i = 1;i < md;++ i) w[md + i] = (LL)w[md + i - 1] * Wn % mod;
	}
	if(n & 1){
		int m0 = m >> 1; n >>= 1; A[0] = 1; B[0] = m0;
		for(int i = 1;i <= n;++ i)
			A[i] = A[i - 1] * (m0 + mod - i + 1ll) % mod * (m0 + i) % mod * iv[i * 2 - 1] % mod * iv[i * 2] % mod;
		getinv(A, n + 1);
		for(int i = 1;i <= n;++ i)
			B[i] = (LL)B[i - 1] * (m0 + mod - i) % mod * (m0 + i) % mod * iv[i * 2] % mod * iv[i * 2 + 1] % mod;
		
		NTT(ans, 0); NTT(B, 0);
		for(int i = 0;i < lim;++ i) B[i] = (LL)B[i] * ans[i] % mod;
		NTT(B, 1); memset(B + n + 1, 0, (lim - n - 1) << 2);
		for(int i = 1;i <= n;i += 2) B[i] = mod - B[i];
		if(m & 1) ++ B[0];
		for(int i = 0;i <= n;++ i) C[i] = B[i] * (2 * i + 1ll) % mod;
		
		NTT(B, 0);
		for(int i = 0;i < lim;++ i) B[i] = (LL)B[i] * B[i] % mod;
		NTT(B, 1); memset(B + n + 1, 0, (lim - n - 1) << 2);
		for(int i = n;i;-- i) B[i] = mod - B[i - 1];
		B[0] = 1; memset(ans, 0, lim << 2); getinv(B, n + 1);
		for(int i = 0;i <= n;++ i) res = (res + (LL)C[i] * ans[n - i]) % mod;
	} else {
		n >>= 1; A[0] = 1;
		for(int i = 1;i <= n;++ i)
			A[i] = A[i - 1] * (m + mod - i + 1ll) % mod * (m + i) % mod * iv[i * 2 - 1] % mod * iv[i * 2] % mod;
		getinv(A, n + 1);
		for(int i = 1;i <= n;++ i) res = (res + (LL)A[i] * i % mod * ans[n - i]) % mod;
		if(!(n & 1) && res) res = mod - res;
	}
	cout << res << '\n';
}
posted @ 2022-07-23 02:33  mizu164  阅读(117)  评论(0编辑  收藏  举报