[2022牛客多校赛第四场] C-Easy Counting Problem
题目大意
统计长度为\(n\)且数位\(i\)出现至少\(c_i\)次的数字串数量。
\(i\in[0,w)\) \((2\leq w\leq 10)\)
\(1\leq c_i\leq 50000,\sum c_i\leq 50000\)
\(q (1\leq q\leq 300)\) 次询问,每次询问 \(n (1\leq n\leq 10^7)\)
题解
若 \(i\) 恰好出现 \(c_i\) 次,且 \(n=\sum c_i\),则方案数为 \(\frac{n!}{\prod_i c_i !}\)。
考虑指数型生成函数,
数位 \(i\) 至少出现 \(c_i\) 次,所以需要减去次数小于 \(c_i\) 的项,即 \(\exp(x)-\sum_{i=0}^{c_i-1}\frac{x^i}{i!}\)
则最终答案为
但是 \(n\) 太大了,不好直接卷积出来。
令 \(g_k(x)=\sum_{i=0}^{c_k-1}\frac{x^i}{i!}\)
考虑一下这个式子
发现 \(\prod_{k=1}^{w}\left(\exp(x)-g_k(x)\right)\) 是 \(\sum_{k=0}^{w} \exp(kx)f_k(x)\) 的形式。
于是我们想要求出 \(f_0(x),f_1(x),\cdots,f_w(x)\)
设 \(dp(i,j)\) 只考虑前 \(i\) 个 \((\exp(x)-g(x))\) 时 \(f_j(x)\) 的值,则有
设 \(s=\sum c_i\),于是只要进行 \(w^2\) 次卷积就能计算出 \(f_0(x),f_1(x),\cdots,f_w(x)\),时间复杂度 \(O(w^2s\log s)\)。
对于单次询问,给定 \(n\),我们需要计算出 \(n!\sum_{k=0}^{w} \exp(kx)f_k(x)\) 的 \(n\) 次项系数即为答案。对于每个 \(\exp(kx)f_k(x)\),直接暴力卷积 \(n\) 次项,因为 \(f_k(x)\) 最高不超过 \(s\) 次项,所以时间复杂度为 \(O(s)\),每次询问要暴力求 \(w\) 个卷积,复杂度 \(O(ws)\),\(q\) 次询问,复杂度 \(O(qws)\)。还需要 \(O(n)\) 预处理阶乘及其逆元,综上,本题的时间复杂度为 \(O(n+w^2s\log s+qws)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define LL long long
template<typename elemType>
inline void Read(elemType& T) {
elemType X = 0, w = 0; char ch = 0;
while (!isdigit(ch)) { w |= ch == '-';ch = getchar(); }
while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
T = (w ? -X : X);
}
const LL MOD = 998244353;
const int maxn = 1e7 + 5;
LL qpow(LL b, LL n, LL MOD) {
if (MOD == 1) return 0;
LL x = 1, Power = b % MOD;
while (n) {
if (n & 1) x = x * Power % MOD;
Power = Power * Power % MOD;
n >>= 1;
}
return x;
}
namespace Poly {
const int maxn = 2100000;
int r[maxn];
int L, limit;
const LL P = 998244353, G = 3, Gi = 332748118;
LL pinv(LL x) { return qpow(x, P - 2, P); }
void NTT(LL* A, int type) {
for (int i = 0; i < limit; i++)
if (i < r[i]) swap(A[i], A[r[i]]);
for (int mid = 1; mid < limit; mid <<= 1) {
LL Wn = qpow(type == 1 ? G : Gi, (P - 1) / (mid << 1), P);
for (int j = 0; j < limit; j += (mid << 1)) {
LL w = 1;
for (int k = 0; k < mid; k++, w = (w * Wn) % P) {
int x = A[j + k], y = w * A[j + k + mid] % P;
A[j + k] = (x + y) % P;
A[j + k + mid] = (x - y + P) % P;
}
}
}
if (type == 1) return;
LL inv_limit = pinv(limit);
for (int i = 0; i < limit; ++i)
A[i] = A[i] * inv_limit % P;
}
void Conv(LL* a, int N, LL* b, LL M, LL* c) {
L = 0; limit = 1;
while (limit <= N + M) limit <<= 1, L++;
for (int i = N;i < limit;++i) a[i] = 0;
for (int i = M;i < limit;++i) b[i] = 0;
for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
NTT(a, 1); NTT(b, 1);
for (int i = 0; i < limit; i++) c[i] = a[i] * b[i] % P;
NTT(c, -1);
}
}
LL f[200010], g[200010];
int inv[maxn], fact[maxn], finv[maxn], c[11];
vector<LL> dp[2][11];
int q, w;
void init() {
inv[1] = fact[0] = fact[1] = finv[0] = finv[1] = 1;
for (int i = 2;i <= 1e7;++i) {
inv[i] = ((-1LL * (MOD / i) * inv[MOD % i]) % MOD + MOD) % MOD;
fact[i] = 1LL * fact[i - 1] * i % MOD;
finv[i] = 1LL * finv[i - 1] * inv[i] % MOD;
}
}
void add(const vector<LL>& a, const vector<LL>& b, vector<LL>& c) {
c.clear(); c.resize(max(a.size(), b.size()));
int p = 0;
while (p < a.size() && p < b.size()) { c[p] = (a[p] + b[p]) % MOD; ++p; }
while (p < a.size()) { c[p] = a[p]; ++p; }
while (p < b.size()) { c[p] = b[p]; ++p; }
}
void conv(const vector<LL>& a, const vector<LL>& b, vector<LL>& c) {
int n = a.size(), m = b.size();
for (int i = 0;i < n;++i) f[i] = a[i];
for (int i = 0;i < m;++i) g[i] = b[i];
Poly::Conv(f, n, g, m, f);
c.clear(); c.resize(n + m - 1);
for (int i = 0;i < n + m - 1;++i) c[i] = f[i];
}
LL solve(int n) {
LL ans = 0;
for (int k = 0;k <= w;++k) {
LL k_inv = qpow(k, MOD - 2, MOD);
LL kk = qpow(k, n, MOD);
for (int i = 0;i < dp[w & 1][k].size() && i <= n;++i) {
ans = (ans + dp[w & 1][k][i] * finv[n - i] % MOD * kk % MOD) % MOD;
kk = kk * k_inv % MOD;
}
}
ans = ans * fact[n] % MOD;
return ans;
}
vector<LL> vec;
int main() {
init();
Read(w);
for (int i = 1;i <= w;++i)
Read(c[i]);
dp[0][0].push_back(1);
for (int i = 1;i <= w;++i) {
for (int j = 0;j <= i;++j) {
dp[i & 1][j].clear();
vec.clear(); vec.resize(c[i]);
for (int k = 0;k < c[i];++k)
vec[k] = MOD - finv[k];
conv(dp[(i & 1) ^ 1][j], vec, dp[i & 1][j]);
if (j == 0) continue;
add(dp[i & 1][j], dp[(i & 1) ^ 1][j - 1], vec);
dp[i & 1][j] = vec;
}
}
int s = 0;
for (int i = 1;i <= w;++i) s += c[i];
Read(q);
while (q--) {
int n; Read(n);
if (n < s) printf("0\n");
else printf("%lld\n", solve(n));
}
return 0;
}