AtCoder Beginner Contest 241 Ex Card Deck Score
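答案即为:\(\sum\limits_{\substack{c_1 + c_2 + \cdots + c_n = m \\ 0 \le c_i \le b_i}} \prod\limits_{i = 1}^n a_i^{c_i}\)。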
考虑生成函数,设 \(F_i(x) = \sum\limits_{j = 0}^{b_i} (a_i x)^j\)。那么答案即为 \([x^m] \prod\limits_{i = 1}^n F_i(x)\)。
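考虑 \(F_i(x) = \sum\limits_{j = 0}^{b_i} (a_i x)^j = \frac{1 - (a_i x)^{b_i + 1}}{1 - a_i x}\)。分子 \(\prod\limits_{i = 1}^n (1 - (a_i x)^{b_i + 1})\) 可以 \(O(2^n)\) 枚举每个因子取 \(1\) 还是 \(-(a_i x)^{b_i + 1}\),这样可以知道 \(\frac{1}{\prod\limits_{i = 1}^n (1 - a_i x)}\) 需要贡献 \(x\) 的多少次方。我们重点关注分母部分。设现在要求 \(\frac{1}{\prod\limits_{i = 1}^n (1 - a_i x)}\) 的 \(x^q\) 项系数。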
有一个常见套路是(当 \(a_i\) 互不相同时)做部分分式分解:设 \(\sum\limits_{i = 1}^n \frac{f_i}{1 - a_i x} = \frac{1}{\prod\limits_{i = 1}^n (1 - a_i x)}\)。因为 \(\sum\limits_{i = 1}^n \frac{f_i}{1 - a_i x} = \frac{\sum\limits_{i = 1}^n f_i \prod\limits_{j \ne i} (1 - a_j x)}{\prod\limits_{i = 1}^n (1 - a_i x)}\),所以分子 \(\sum\limits_{i = 1}^n f_i \prod\limits_{j \ne i} (1 - a_j x)\) 必须恒等于 \(1\);比较 \(x^0, x^1, \ldots, x^{n - 1}\) 的系数即可列出 \(n\) 个关于 \(f_i\) 的方程,可以高斯消元求解。
那么 \([x^q] \frac{1}{\prod\limits_{i = 1}^n (1 - a_i x)} = [x^q] \sum\limits_{i = 1}^n \frac{f_i}{1 - a_i x} = \sum\limits_{i = 1}^n f_i ([x^q] \sum\limits_{j \ge 0} (a_i x)^j) = \sum\limits_{i = 1}^n f_i a_i^q\)。
总时间复杂度 \(O(n (2^n \log m + n^2))\)。
code
// Problem: Ex - Card Deck Score
// Contest: AtCoder - AtCoder Beginner Contest 241(Sponsored by Panasonic)
// URL: https://atcoder.jp/contests/abc241/tasks/abc241_h
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
// Abbreviations for frequently used container / pair operations.
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
// Fills the whole array `a` with the byte value `x`; relies on sizeof(a),
// so `a` must be a real array, not a pointer.
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
// Short aliases for the numeric types used below.
using ll = long long;
using ull = unsigned long long;
using db = double;
using ldb = long double;
using pii = pair<ll, ll>;
// Capacity of the fixed-size arrays declared below.
constexpr int maxn = 20;
// Modulus used for all arithmetic.
constexpr ll mod = 998244353;
// Computes b^p modulo `mod` by binary exponentiation.
//
// b: base; any value is accepted. It is normalised into [0, mod) up front so
//    that the squaring step below cannot overflow a signed 64-bit integer and
//    a negative base still yields a result in [0, mod).
// p: exponent; must be non-negative.
// Returns 1 when p == 0.
inline ll qpow(ll b, ll p) {
    b %= mod;
    if (b < 0) {
        b += mod;
    }
    ll res = 1;
    while (p) {
        // Multiply in the current power of b whenever the matching bit of p is set.
        if (p & 1) {
            res = res * b % mod;
        }
        b = b * b % mod;
        p >>= 1;
    }
    return res;
}
// Problem data filled in by solve(), indexed from 1: n factors, target degree m,
// the pairs (a[i], b[i]) as read from input, and c[i] = a[i]^(b[i] + 1) modulo `mod`.
ll n, m, a[maxn], b[maxn], c[maxn];
// A polynomial stored as its coefficient vector; index k holds the coefficient of x^k.
typedef vector<ll> poly;
// Multiplies two polynomials given as coefficient vectors (index k holds the
// coefficient of x^k), reducing every coefficient modulo `mod`.
//
// Uses the schoolbook double loop, i.e. lhs.size() * rhs.size() multiplications.
// Coefficients are expected to lie in [0, mod) so each product fits in ll.
// An empty operand is treated as the zero polynomial and gives an empty result;
// this avoids computing a negative size for the result vector.
inline poly operator * (const poly &a, const poly &b) {
    if (a.empty() || b.empty()) {
        return poly();
    }
    const size_t len_a = a.size();
    const size_t len_b = b.size();
    // deg(a * b) = deg(a) + deg(b), so the product has len_a + len_b - 1 coefficients.
    poly res(len_a + len_b - 1);
    for (size_t i = 0; i < len_a; ++i) {
        for (size_t j = 0; j < len_b; ++j) {
            res[i + j] = (res[i + j] + a[i] * b[j]) % mod;
        }
    }
    return res;
}
// Augmented matrix of the linear system solved in solve() (indexed from 1);
// column n + 1 holds the right-hand side and, after elimination, the solution.
ll f[maxn][maxn];
// F[i] = prod over j != i of (1 - a[j] x), built in solve().
poly F[maxn];
// Reads n, m and the n pairs (a[i], b[i]) from stdin and prints one number modulo `mod`:
// the coefficient of x^m in prod_i (1 - (a_i x)^(b_i + 1)) / (1 - a_i x).
//
// Method:
//   1. Find f_1..f_n with sum_i f_i / (1 - a_i x) = 1 / prod_i (1 - a_i x) by solving
//      an n x n linear system with Gauss-Jordan elimination modulo `mod`.
//   2. Enumerate every subset S of numerator factors that contribute -(a_i x)^(b_i + 1);
//      the remaining degree m - q is taken from the partial fractions, whose x^k
//      coefficient is sum_i f_i * a_i^k.
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%lld%lld", &a[i], &b[i]);
// c[i] = a_i^(b_i + 1): magnitude of the high-order term of the i-th numerator factor.
c[i] = qpow(a[i] % mod, b[i] + 1);
}
// F[i] = prod_{j != i} (1 - a_j x); each linear factor is stored as {1, mod - a[j]}.
for (int i = 1; i <= n; ++i) {
F[i] = poly(1, 1);
for (int j = 1; j <= n; ++j) {
if (i != j) {
poly p;
p.pb(1);
p.pb(mod - a[j]);
F[i] = F[i] * p;
}
}
}
// Row i of the system equates the x^(i - 1) coefficient of sum_j f_j * F[j] with that
// of the constant polynomial 1, so the right-hand side is 1 for i == 1 and 0 otherwise.
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
f[i][j] = F[j][i - 1];
}
f[i][n + 1] = (i == 1);
}
// Gauss-Jordan elimination modulo `mod`.
for (int i = 1; i <= n; ++i) {
// Bring a row with a non-zero entry in column i into position i.
// NOTE(review): if no such row exists the code continues with f[i][i] == 0;
// presumably the input guarantees a non-singular system - confirm against the problem statement.
for (int j = i; j <= n; ++j) {
if (f[j][i]) {
swap(f[i], f[j]);
break;
}
}
// Scale row i so the pivot becomes 1; the inverse is computed as pivot^(mod - 2).
ll t = qpow(f[i][i], mod - 2);
for (int j = i; j <= n + 1; ++j) {
f[i][j] = f[i][j] * t % mod;
}
// Eliminate column i from every other row. Only columns > i are updated;
// column i itself is left as-is because it is never read again.
for (int j = 1; j <= n; ++j) {
if (i == j) {
continue;
}
ll t = (mod - f[j][i]) % mod;
for (int k = i + 1; k <= n + 1; ++k) {
f[j][k] = (f[j][k] + t * f[i][k]) % mod;
}
}
}
// f[i][n + 1] now holds the partial-fraction coefficient f_i.
ll ans = 0;
for (int S = 0; S < (1 << n); ++S) {
// p: signed product of the chosen high-order numerator terms; q: degree they consume.
ll p = 1, q = 0;
for (int i = 1; i <= n; ++i) {
if (S & (1 << (i - 1))) {
// Each chosen factor contributes -c[i], hence the negation modulo `mod`.
p = (mod - p * c[i] % mod) % mod;
q += b[i] + 1;
}
}
// The chosen terms already exceed degree m, so this subset contributes nothing.
if (q > m) {
continue;
}
// Add p * [x^(m - q)] of the partial fractions = p * sum_i f_i * a_i^(m - q).
for (int i = 1; i <= n; ++i) {
ans = (ans + p * f[i][n + 1] % mod * qpow(a[i], m - q)) % mod;
}
}
printf("%lld\n", ans);
}
// Entry point: runs solve() once per test case. The input holds a single case,
// so the case count is fixed at 1 rather than read from stdin.
int main() {
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}