AT5202 [AGC038E] Gachapon(min-max)
AT5202 [AGC038E] Gachapon(min-max)
题目大意
有一个随机数生成器,生成 \([0,n-1]\) 之间的整数,其中生成 \(i\) 的概率为 \(\frac{A_i}{S}\),其中,\(S=\sum A_i\)。
这个随机数生成器不断生成随机数,当 \(\forall i\in[0,n-1]\),\(i\) 至少出现了 \(B_i\) 次时,停止生成,否则继续生成。
求期望生成随机数的次数,输出答案对 \(998244353\) 取模的结果。
数据范围
\(A_i,B_i\geq 1\),\(\sum A_i,\sum B_i,n\leq 400\)。
解题思路
显然是一个 min-max 反演
\[Ans = \sum_{T \subseteq S}(-1)^{|T|+1}\frac {S}{\sum_{i\in T}A_i}f(T)
\]
其中,\(f(T)\) 表示 T 集合中第一个至少出现了 \(B_i\) 次的期望次数。
考虑暴力求 T 集合的答案
\[f(T) = \sum_{i=1}P(x=i)\times i=\sum_{i=0}^{sumB}P(x > i)
\]
如何求出 \(P(x>i)\) 呢?考虑用方案数除以总方案数,方案数就是一个背包问题,用生成函数表示是
\[f(T)=\sum_{i=0}\left[\frac {x^i}{i!}\right]\left(\prod_{j\in T}\sum_{t=0}^{B_j-1}A_j^t\frac {x^t}{t!}\right)
\]
容易发现我们只用一维即可,时间复杂度是 \(\Theta(n^2)\) 的
观察发现只要选中的生成函数不会变,而且前面的 \((-1)^{|T|+1}\) 可以乘进去,又发现 \(S\) 很小,我们用另一维状态去压缩它即可,时间复杂度 \(\Theta(n^3)\),最后统计答案即可。
/*
/> フ
| _ _|
/`ミ _x 彡
/ |
/ ヽ ?
/ ̄| | | |
| ( ̄ヽ__ヽ_)_)
\二つ
*/
#include <queue>
#include <vector>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MP make_pair
#define ll long long
#define fi first
#define se second
using namespace std;
template <typename T>
void read(T &x) {
x = 0; bool f = 0;
char c = getchar();
for (;!isdigit(c);c=getchar()) if (c=='-') f=1;
for (;isdigit(c);c=getchar()) x=x*10+(c^48);
if (f) x=-x;
}
template<typename F>
inline void write(F x, char ed = '\n') {
static short st[30];short tp=0;
if(x<0) putchar('-'),x=-x;
do st[++tp]=x%10,x/=10; while(x);
while(tp) putchar('0'|st[tp--]);
putchar(ed);
}
template <typename T>
inline void Mx(T &x, T y) { x < y && (x = y); }
template <typename T>
inline void Mn(T &x, T y) { x > y && (x = y); }
const int P = 998244353;
const int N = 405;
ll inv[N], fac[N], a[N], b[N], A, B, ans, n;
ll fpw(ll x, ll mi) {
ll res = 1;
for (; mi; mi >>= 1, x = x * x % P)
if (mi & 1) res = res * x % P;
return res;
}
ll g[N][N], f[N][N];
int main() {
read(n);
inv[0] = fac[0] = inv[1] = fac[1] = 1;
for (int i = 2;i <= 400; i++) inv[i] = (P - P / i) * inv[P % i] % P;
for (int i = 2;i <= 400; i++)
inv[i] = inv[i-1] * inv[i] % P,
fac[i] = fac[i-1] * i % P;
for (int i = 1;i <= n; i++) {
read(a[i]), read(b[i]);
A += a[i], B += b[i];
}
f[0][0] = -1;
/* for (int i = 1;i <= 50; i++) write(inv[i], ' '), write(fac[i]); */
for (int i = 1;i <= n; i++) {
memcpy(g, f, sizeof(g));
for (int s = A;s >= 0; s--) {
for (int j = B;j >= 0; j--) {
if (s < a[i]) { f[s][j] = 0; continue; }
ll t = a[i];
f[s][j] = f[s-a[i]][j];
for (int k = 1;k < b[i]; k++, t = t * a[i] % P)
f[s][j] = (f[s][j] + t * inv[k] % P * f[s-a[i]][j-k]) % P;
}
}
for (int j = 0;j <= A; j++)
for (int k = 0;k <= B; k++)
f[j][k] = (g[j][k] - f[j][k] + P) % P;
}
for (int s = 1;s <= A; s++) {
ll tt = fpw(s, P - 2), t = A * tt % P;
ll res = 0;
for (int i = 0;i <= B; i++, t = t * tt % P)
res = (res + f[s][i] * fac[i] % P * t) % P;
res %= P, ans = (ans + res) % P;
}
write(ans);
return 0;
}
*/