CF1580F Problems for Codeforces 【生成函数,组合计数】
给定正整数 \(n,m\),求有多少个正整数序列 \(a_1,\cdots,a_n\) 使得 \(a_i+a_{i+1}<m\) 且 \(a_1+a_n<m\),答案对 \(998\,244\,353\) 取模。
\(n\le 5\cdot 10^4\),\(m\le 10^9\)。
先看 \(n\) 是偶数的情况:当 \(i\) 为奇数时把 \(a_i\) 改为 \(m-1-a_i\),条件变为 \(a_1\le a_2\ge\cdots\le a_n\ge a_1\),对 \(\ge\) 进行一个容斥,计算选择一些 \(\ge\) 变为 \(<\),其他忽略掉的方案数。长为 \(2l\) 的段的填数方案数为 \(\dbinom{m+l}{2l}\),设生成函数 \(\displaystyle F(z)=\sum_{l\ge 1}\binom{m+l}{2l}z^l\),过程就是按顺序确定好段长,然后确定 \(a_1\) 在段上的位置,从而答案即为 \((-1)^{n/2-1}[z^{n/2-1}]\dfrac{F'(z)}{1+F(z)}\)。
然后就 \(n\) 是奇数的情况:设 \(m_0=\lfloor m/2\rfloor\),\(m_1=\lceil m/2\rceil\),考虑 \([a_i\ge m_1]\),不能有相邻两个都为 \(1\),则整个环被划分为了若干段,每段形如 \(010\cdots 010\)。将不小于 \(m_1\) 的 \(a_i\) 减去 \(m_1\),此时 \(0\le a_i\le m_0\),若存在 \(a_i=m_0\),则 \(m\) 为奇数且 \(a_i\) 的相邻两项都小于 \(m_1\),从而只会在长为 \(1\) 的段上出现,特判一下。
现在要求计算 \(a_i\in[0,m_0)\) 使得 \(a_0\le a_1\ge\cdots\le a_{2l-1}\ge a_{2l}\) 的方案数,仍然对 \(\ge\) 容斥,设 \(\displaystyle G(z)=\sum_{l\ge 0}\binom{m_0+l}{2l+1}\),则对应方案数的生成函数是 \(P(z)=\dfrac{G(-z)}{1+F(-z)}+[2\nmid m]z\),同样拼一拼,设 \(P(z)\) 逐位乘上 \(2i+1\) 得到 \(Q(z)\),则答案为 \([z^{(n-1)/2}]\dfrac{Q(z)}{1-zP^2(z)}\)。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1 << 16, mod = 998244353;
int ksm(int a, int b){
int res = 1;
for(;b;b >>= 1, a = (LL)a * a % mod)
if(b & 1) res = (LL)res * a % mod;
return res;
}
int n, m, w[N], rev[N], lim, A[N], B[N], C[N], iv[N], res;
void calrev(int len){
int L = -1; lim = 1;
while(lim <= len){lim <<= 1; ++ L;}
for(int i = 1;i < lim;++ i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L);
}
void NTT(int *A, bool op){
for(int i = 0;i < lim;++ i)
if(i < rev[i]) swap(A[i], A[rev[i]]);
for(int md = 1;md < lim;md <<= 1)
for(int i = 0;i < lim;i += md << 1)
for(int j = 0;j < md;++ j){
int y = LL(op && j ? mod - w[(md << 1) - j] : w[md + j]) * A[md + i + j] % mod;
if((A[md + i + j] = A[i + j] - y) < 0) A[md + i + j] += mod;
if((A[i + j] += y) >= mod) A[i + j] -= mod;
}
if(op){
int inv = ksm(lim, mod - 2);
for(int i = 0;i < lim;++ i) A[i] = (LL)A[i] * inv % mod;
}
}
int ans[N], tmp[N];
void getinv(int *A, int d){
if(d == 1){ans[0] = ksm(A[0], mod - 2); return;}
getinv(A, (d + 1) >> 1); calrev(d << 1);
memcpy(tmp, A, d << 2); memset(tmp + d, 0, (lim - d) << 2);
NTT(tmp, 0); NTT(ans, 0);
for(int i = 0;i < lim;++ i)
ans[i] = (mod + 2ll - (LL)ans[i] * tmp[i] % mod) * ans[i] % mod;
NTT(ans, 1); memset(ans + d, 0, (lim - d) << 2);
}
int main(){
ios::sync_with_stdio(0);
cin >> n >> m; iv[1] = 1;
for(int i = 2;i < N;++ i) iv[i] = mod - (LL)mod / i * iv[mod % i] % mod;
for(int md = 1;md < N;md <<= 1){
int Wn = ksm(3, (mod - 1) / (md << 1)); w[md] = 1;
for(int i = 1;i < md;++ i) w[md + i] = (LL)w[md + i - 1] * Wn % mod;
}
if(n & 1){
int m0 = m >> 1; n >>= 1; A[0] = 1; B[0] = m0;
for(int i = 1;i <= n;++ i)
A[i] = A[i - 1] * (m0 + mod - i + 1ll) % mod * (m0 + i) % mod * iv[i * 2 - 1] % mod * iv[i * 2] % mod;
getinv(A, n + 1);
for(int i = 1;i <= n;++ i)
B[i] = (LL)B[i - 1] * (m0 + mod - i) % mod * (m0 + i) % mod * iv[i * 2] % mod * iv[i * 2 + 1] % mod;
NTT(ans, 0); NTT(B, 0);
for(int i = 0;i < lim;++ i) B[i] = (LL)B[i] * ans[i] % mod;
NTT(B, 1); memset(B + n + 1, 0, (lim - n - 1) << 2);
for(int i = 1;i <= n;i += 2) B[i] = mod - B[i];
if(m & 1) ++ B[0];
for(int i = 0;i <= n;++ i) C[i] = B[i] * (2 * i + 1ll) % mod;
NTT(B, 0);
for(int i = 0;i < lim;++ i) B[i] = (LL)B[i] * B[i] % mod;
NTT(B, 1); memset(B + n + 1, 0, (lim - n - 1) << 2);
for(int i = n;i;-- i) B[i] = mod - B[i - 1];
B[0] = 1; memset(ans, 0, lim << 2); getinv(B, n + 1);
for(int i = 0;i <= n;++ i) res = (res + (LL)C[i] * ans[n - i]) % mod;
} else {
n >>= 1; A[0] = 1;
for(int i = 1;i <= n;++ i)
A[i] = A[i - 1] * (m + mod - i + 1ll) % mod * (m + i) % mod * iv[i * 2 - 1] % mod * iv[i * 2] % mod;
getinv(A, n + 1);
for(int i = 1;i <= n;++ i) res = (res + (LL)A[i] * i % mod * ans[n - i]) % mod;
if(!(n & 1) && res) res = mod - res;
}
cout << res << '\n';
}