「PKUWC2018」猎人杀
好题好题
这个分母一直在变,看上去完全不知道怎么去算
考虑容斥一波,设\(g_i\)表示至少有\(i\)死在\(1\)之后的概率,那么答案就是\(\sum_{i=0}^n(-1)^ig_i\)
考虑一下\(g_i\)怎么算,发现又不会算了
考虑一个问题,就是现在有两个人\(i,j\),其中\(i\)比\(j\)先死的概率是多少
这个概率是\(\frac{w_i}{w_i+w_j}\),瞎猜一下大概是这个样子的
我们设\(p_i\)为\(i\)比\(j\)先死的概率,设\(S=\sum_{i=1}^nw_i\),那么就有
\[p_i=\frac{w_i}{S}+\frac{S-w_i-w_j}{S}p_i
\]
这样解一下方程就能得到\(p_i=\frac{w_i}{w_i+w_j}\)了
我们考虑刚才那个\(g_i\)还是非常不好求的样子,我们先考虑对于一个集合\(T\),这个集合\(T\)里的人都在\(1\)之后死的概率
显然我们可以把\(T\)集合里的人合并成一个人,那么容斥之后的答案就是
\[\sum_{T\subset S}(-1)^{|T|}\frac{w_1}{w_1+\sum T}
\]
我们注意到\(\sum_{i=1}^nw_i \leq 10^5\),这启示我们把每一个分母出现的次数都算出来
考虑到前面还有一个容斥系数,我们可以把每个人都写成一个生成函数\(1-x^{w_i}\)
分治\(ntt\)求一下\(\prod_{i=1}^n1-x^{w_i}\)就能算每一个分母出现的次数了
代码
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#define pb push_back
#define re register
#define LL long long
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
const int mod = 998244353;
const int maxn = 262144 + 5;
const int G[2] = { 3, (mod + 1) / 3 };
inline int read() {
char c = getchar();
int x = 0;
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + c - 48, c = getchar();
return x;
}
std::vector<int> q[maxn * 3];
int n, len, ra[maxn], la[maxn], inv[maxn], a[maxn], rev[maxn], pre[maxn];
int __og[25][2];
inline int ksm(int a, int b) {
int S = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod)
if (b & 1)
S = 1ll * S * a % mod;
return S;
}
inline void NTT(int *f, int o) {
for (re int i = 0; i < len; i++)
if (i < rev[i])
std::swap(f[i], f[rev[i]]);
for (re int w = 0, i = 2; i <= len; i <<= 1, w++) {
int ln = i >> 1;
int og1;
if (!__og[w][o])
og1 = __og[w][o] = ksm(G[o], (mod - 1) / i);
else
og1 = __og[w][o];
for (re int t, og = 1, l = 0; l < len; l += i, og = 1)
for (re int x = l; x < l + ln; ++x) {
t = 1ll * f[x + ln] * og % mod, og = 1ll * og * og1 % mod;
f[x + ln] = (f[x] - t + mod) % mod, f[x] = (f[x] + t) % mod;
}
}
if (!o)
return;
int Inv = ksm(len, mod - 2);
for (re int i = 0; i < len; i++) f[i] = 1ll * f[i] * Inv % mod;
}
void cdq(int l, int r, int t) {
if (l == r) {
q[t].pb(1);
for (re int i = 1; i < a[l]; i++) q[t].pb(0);
q[t].pb(mod - 1);
return;
}
int mid = l + r >> 1;
cdq(l, mid, t << 1), cdq(mid + 1, r, t << 1 | 1);
len = 1;
while (len <= pre[r] - pre[l - 1]) len <<= 1;
for (re int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | ((i & 1) ? len >> 1 : 0);
for (re int i = 0; i < q[t << 1].size(); i++) la[i] = q[t << 1][i];
for (re int i = q[t << 1].size(); i < len; i++) la[i] = 0;
for (re int i = 0; i < q[t << 1 | 1].size(); i++) ra[i] = q[t << 1 | 1][i];
for (re int i = q[t << 1 | 1].size(); i < len; i++) ra[i] = 0;
NTT(la, 0), NTT(ra, 0);
for (re int i = 0; i < len; i++) la[i] = 1ll * la[i] * ra[i] % mod;
NTT(la, 1);
for (re int i = 0; i <= pre[r] - pre[l - 1]; i++) q[t].pb(la[i]);
}
inline int calc(int x) { return 1ll * a[0] * inv[a[0] + x] % mod; }
int main() {
n = read();
inv[1] = 1;
for (re int i = 0; i < n; i++) a[i] = read();
pre[0] = a[0];
for (re int i = 1; i < n; i++) pre[i] = pre[i - 1] + a[i];
cdq(1, n - 1, 1);
for (re int i = 2; i <= pre[n - 1]; i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
int ans = 0;
for (re int i = 0; i < q[1].size(); i++) ans = (ans + 1ll * q[1][i] * calc(i) % mod) % mod;
printf("%d\n", ans);
return 0;
}