CF1096G.Lucky Tickets[DP+卷积+多项式求幂] Educational Codeforces Round 57
给定一个偶数长度n和字符集(0..9中的一些数字),问有多少个只由字符集中的数字组成、长度为n的串,满足前 \(\frac{n}{2}\) 位的数位和与后 \(\frac{n}{2}\) 位的数位和相等(答案对998244353取模)
\[f\left( i,j \right) \text{表示}i\text{个数的和是}j\text{的方案数}
\\
\text{答案是}\sum_y{f\left( \frac{n}{2},y \right)}^2\text{(前半部分和后半部分数位和为}y\text{的方案数都是}f\left( \frac{n}{2},y \right) \text{,两者相乘后对}y\text{求和)}
\\
\text{所以只需要考虑怎么求}f
\\
f\left( x+\text{1,}y \right) =\sum_{i=0}^9{f\left( x,y-i \right)}
\\
\text{可以把上面的柿子补成卷积形式}
\\
\text{令}H=f\left( x+1 \right) \text{,}F=f\left( x \right) \text{,}H\left( y \right) \text{表示的就是}f\left( x+\text{1,}y \right) \text{,}F\text{同理}
\\
H\left( k \right) =\sum_{i=0}^{\min \left( k,9 \right)}{F\left( k-i \right)}\ast G\left( i \right)
\\
\text{即}f\left( x+1 \right) \left( k \right) =\sum_{i=0}^{\min \left( k,9 \right)}{f\left( x \right) \left( k-i \right)}\ast G\left( i \right)
\\
\text{比较容易发现没有字符集限制的情况下}G\left( i \right) =1\left( 0\le i\le 9 \right)
\\
\text{有限制的时候}G=\sum_{i=0}^9{a_ix^i}\text{,}a_i\text{表示}i\text{是否在允许的字符集中,}x\text{没有实际意义}
\\
\text{所以要做的就是求出}G\left( x \right) ^{\frac{n}{2}}\text{然后统计每一项系数的平方和}
\]
贴个Tutorial里的代码
#include<bits/stdc++.h>
using namespace std;
// Number of table levels built by precalc(); also determines the buffer size N.
const int LOGN = 21;
// Capacity (2^LOGN entries) of the global coefficient buffers.
const int N = (1 << LOGN);
// Prime modulus used by all modular arithmetic in this file.
const int MOD = 998244353;
// Base from which precalc() derives its roots of unity
// (3 is a primitive root modulo 998244353).
const int g = 3;
// forn(i, n): loop with int i running over 0, 1, ..., n - 1.
#define forn(i, n) for(int i = 0; i < int(n); i++)
// Modular product: returns (a * b) % MOD. The multiplication is carried out
// in 64-bit arithmetic so the intermediate product of two ints cannot overflow.
inline int mul(int a, int b)
{
    const long long wide = 1ll * a * b;
    return wide % MOD;
}
// Brings a into the range [0, MOD) by repeatedly subtracting MOD while it is
// too large and repeatedly adding MOD while it is negative.
inline int norm(int a)
{
    for (; a >= MOD; a -= MOD) {}
    for (; a < 0; a += MOD) {}
    return a;
}
// Square-and-multiply exponentiation: returns a^k with every product reduced
// through mul(). Returns 1 when k <= 0.
inline int binPow(int a, int k)
{
    int result = 1;
    for (; k > 0; k >>= 1)
    {
        // Fold the current power of a into the result when this bit of k is set.
        if (k & 1)
            result = mul(result, a);
        a = mul(a, a);
    }
    return result;
}
// Modular inverse of a, computed as a^(MOD - 2) (Fermat's little theorem;
// MOD is prime). Meaningful only when a is not a multiple of MOD.
inline int inv(int a)
{
    const int exponent = MOD - 2;
    return binPow(a, exponent);
}
// Lookup tables filled by precalc() and read by fft(); one entry per level
// st in [0, LOGN), each holding 1 << st values.
// w[st][k]: k-th power of the root of unity used by butterfly level st.
vector<int> w[LOGN];
// iw[st][k]: k-th power of the inverse of that root (used for inverse transforms).
vector<int> iw[LOGN];
// rv[st][k]: k with its lowest st bits reversed.
vector<int> rv[LOGN];
void precalc()
{
int wb = binPow(g, (MOD - 1) / (1 << LOGN));
for(int st = 0; st < LOGN; st++)
{
w[st].assign(1 << st, 1);
iw[st].assign(1 << st, 1);
int bw = binPow(wb, 1 << (LOGN - st - 1));
int ibw = inv(bw);
int cw = 1;
int icw = 1;
for(int k = 0; k < (1 << st); k++)
{
w[st][k] = cw;
iw[st][k] = icw;
cw = mul(cw, bw);
icw = mul(icw, ibw);
}
rv[st].assign(1 << st, 0);
if(st == 0)
{
rv[st][0] = 0;
continue;
}
int h = (1 << (st - 1));
for(int k = 0; k < (1 << st); k++)
rv[st][k] = (rv[st - 1][k & (h - 1)] << 1) | (k >= h);
}
}
// In-place number-theoretic transform of a[0 .. n) modulo MOD.
//   a       - coefficient buffer, overwritten with the transform
//   n       - transform length; callers pass n == 1 << ln
//   ln      - log2(n); indexes rv[], which has LOGN entries, so ln < LOGN
//   inverse - when true, the iw tables are used instead of w and every
//             element is finally multiplied by inv(n)
// Relies on the tables built by precalc().
inline void fft(int a[N], int n, int ln, bool inverse)
{
    // Bit-reversal permutation; the i < j test swaps each pair exactly once.
    for (int i = 0; i < n; ++i)
    {
        const int j = rv[ln][i];
        if (i < j)
            std::swap(a[i], a[j]);
    }

    // Butterfly passes: at level st, blocks of 2 * half elements are combined.
    for (int st = 0, half = 1; half < n; ++st, half <<= 1)
    {
        const auto &roots = inverse ? iw[st] : w[st];
        for (int start = 0; start < n; start += 2 * half)
        {
            for (int off = 0; off < half; ++off)
            {
                const int lo = start + off;
                const int hi = lo + half;
                const int even = a[lo];
                const int odd = mul(a[hi], roots[off]);
                a[lo] = norm(even + odd);
                a[hi] = norm(even - odd);
            }
        }
    }

    if (inverse)
    {
        const int scale = inv(n);
        for (int i = 0; i < n; ++i)
            a[i] = mul(a[i], scale);
    }
}
// Scratch buffers used by multiply(): aa and bb hold the zero-padded inputs
// and their transforms, cc holds the pointwise product and its inverse transform.
int aa[N], bb[N], cc[N];
// Multiplies the polynomials a[0 .. sza) and b[0 .. szb) modulo MOD.
//   c   - receives n coefficients, where n is the smallest power of two
//         that is >= sza + szb (so the product is not wrapped around)
//   szc - set to n
// Uses the global scratch buffers aa, bb and cc. The chosen level ln is
// passed to fft(), which indexes rv[ln], so it must stay below LOGN.
inline void multiply(int a[N], int sza, int b[N], int szb, int c[N], int &szc)
{
    int n = 1, ln = 0;
    for (; n < sza + szb; n <<= 1)
        ++ln;

    // Copy both inputs into the scratch buffers, zero-padded to length n.
    for (int i = 0; i < n; ++i)
    {
        aa[i] = i < sza ? a[i] : 0;
        bb[i] = i < szb ? b[i] : 0;
    }

    fft(aa, n, ln, false);
    fft(bb, n, ln, false);

    // Pointwise product in the transformed domain, then transform back.
    for (int i = 0; i < n; ++i)
        cc[i] = mul(aa[i], bb[i]);
    fft(cc, n, ln, true);

    szc = n;
    for (int i = 0; i < n; ++i)
        c[i] = cc[i];
}
// Array of N empty vectors; not referenced by any code visible in this file.
vector<int> T[N];
// Coefficient buffer used by main(): a[d] is set to 1 for every allowed digit d,
// then the buffer is transformed, raised to a power pointwise and transformed back.
int a[N];
// b and c are not referenced by any code visible in this file.
int b[N];
int c[N];
// sz(a): size of a container converted to int.
#define sz(a) (int(a.size()))
int main()
{
precalc();
int n, k;
scanf("%d %d", &n, &k);
for(int i = 0; i < k; i++)
{
int x;
scanf("%d", &x);
a[x] = 1;
}
int nn = 1, ln = 0;
int nw = (n * 5) + 1;
while(nn < nw)
{
nn *= 2;
ln++;
}
fft(a, nn, ln, false);
forn(i, nn)
a[i] = binPow(a[i], n / 2);
fft(a, nn, ln, true);
int ans = 0;
forn(i, nn)
ans = norm(ans + binPow(a[i], 2));
printf("%d\n", ans);
return 0;
}