「PKUWC2018」猎人杀

题目

好题好题

这个分母一直在变,看上去完全不知道怎么去算

考虑容斥一波,设\(g_i\)表示至少有\(i\)死在\(1\)之后的概率,那么答案就是\(\sum_{i=0}^n(-1)^ig_i\)

考虑一下\(g_i\)怎么算,发现又不会算了

考虑一个问题,就是现在有两个人\(i,j\),其中\(i\)\(j\)先死的概率是多少

这个概率是\(\frac{w_i}{w_i+w_j}\),瞎猜一下大概是这个样子的

我们设\(p_i\)\(i\)\(j\)先死的概率,设\(S=\sum_{i=1}^nw_i\),那么就有

\[p_i=\frac{w_i}{S}+\frac{S-w_i-w_j}{S}p_i \]

这样解一下方程就能得到\(p_i=\frac{w_i}{w_i+w_j}\)

我们考虑刚才那个\(g_i\)还是非常不好求的样子,我们先考虑对于一个集合\(T\),这个集合\(T\)里的人都在\(1\)之后死的概率

显然我们可以把\(T\)集合里的人合并成一个人,那么容斥之后的答案就是

\[\sum_{T\subset S}(-1)^{|T|}\frac{w_1}{w_1+\sum T} \]

我们注意到\(\sum_{i=1}^nw_i \leq 10^5\),这启示我们把每一个分母出现的次数都算出来

考虑到前面还有一个容斥系数,我们可以把每个人都写成一个生成函数\(1-x^{w_i}\)

分治\(ntt\)求一下\(\prod_{i=1}^n1-x^{w_i}\)就能算每一个分母出现的次数了

代码

#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#define pb push_back
#define re register
#define LL long long
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
const int mod = 998244353;
const int maxn = 262144 + 5;
const int G[2] = { 3, (mod + 1) / 3 };
inline int read() {
    char c = getchar();
    int x = 0;
    while (c < '0' || c > '9') c = getchar();
    while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + c - 48, c = getchar();
    return x;
}
std::vector<int> q[maxn * 3];
int n, len, ra[maxn], la[maxn], inv[maxn], a[maxn], rev[maxn], pre[maxn];
int __og[25][2];
inline int ksm(int a, int b) {
    int S = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1)
            S = 1ll * S * a % mod;
    return S;
}
inline void NTT(int *f, int o) {
    for (re int i = 0; i < len; i++)
        if (i < rev[i])
            std::swap(f[i], f[rev[i]]);
    for (re int w = 0, i = 2; i <= len; i <<= 1, w++) {
        int ln = i >> 1;
        int og1;
        if (!__og[w][o])
            og1 = __og[w][o] = ksm(G[o], (mod - 1) / i);
        else
            og1 = __og[w][o];
        for (re int t, og = 1, l = 0; l < len; l += i, og = 1)
            for (re int x = l; x < l + ln; ++x) {
                t = 1ll * f[x + ln] * og % mod, og = 1ll * og * og1 % mod;
                f[x + ln] = (f[x] - t + mod) % mod, f[x] = (f[x] + t) % mod;
            }
    }
    if (!o)
        return;
    int Inv = ksm(len, mod - 2);
    for (re int i = 0; i < len; i++) f[i] = 1ll * f[i] * Inv % mod;
}
void cdq(int l, int r, int t) {
    if (l == r) {
        q[t].pb(1);
        for (re int i = 1; i < a[l]; i++) q[t].pb(0);
        q[t].pb(mod - 1);
        return;
    }
    int mid = l + r >> 1;
    cdq(l, mid, t << 1), cdq(mid + 1, r, t << 1 | 1);
    len = 1;
    while (len <= pre[r] - pre[l - 1]) len <<= 1;
    for (re int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | ((i & 1) ? len >> 1 : 0);
    for (re int i = 0; i < q[t << 1].size(); i++) la[i] = q[t << 1][i];
    for (re int i = q[t << 1].size(); i < len; i++) la[i] = 0;
    for (re int i = 0; i < q[t << 1 | 1].size(); i++) ra[i] = q[t << 1 | 1][i];
    for (re int i = q[t << 1 | 1].size(); i < len; i++) ra[i] = 0;
    NTT(la, 0), NTT(ra, 0);
    for (re int i = 0; i < len; i++) la[i] = 1ll * la[i] * ra[i] % mod;
    NTT(la, 1);
    for (re int i = 0; i <= pre[r] - pre[l - 1]; i++) q[t].pb(la[i]);
}
inline int calc(int x) { return 1ll * a[0] * inv[a[0] + x] % mod; }
int main() {
    n = read();
    inv[1] = 1;
    for (re int i = 0; i < n; i++) a[i] = read();
    pre[0] = a[0];
    for (re int i = 1; i < n; i++) pre[i] = pre[i - 1] + a[i];
    cdq(1, n - 1, 1);
    for (re int i = 2; i <= pre[n - 1]; i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
    int ans = 0;
    for (re int i = 0; i < q[1].size(); i++) ans = (ans + 1ll * q[1][i] * calc(i) % mod) % mod;
    printf("%d\n", ans);
    return 0;
}
posted @ 2019-06-19 21:36  asuldb  阅读(254)  评论(0编辑  收藏  举报