Loading

【题解】P5644 [PKUWC2018]猎人杀

供题人是树剖姐姐喵 /se

思路

生成函数 + 子集反演 + 分治 NTT.

首先发现当前打中的猎人倒下之后,后面的猎人被射中的概率会随之变化,也就是说操作是有后效性的,不好处理。

有一个高明的转化思路是:考虑转化成每次操作向所有猎人按初始概率开枪,一直到射中当前存活的猎人才结束一次操作。

(一般这种可以无限操作的问题都要无穷等比数列求和)

直觉上是对的,因为每个存活猎人死亡的概率仍然是它的权重与所有存活猎人权重之和的比值。这样每个猎人每次被射中的概率就是定值。

感觉上只能状压记录存活的猎人,否则很难转移状态。因为需要特殊判断 1 号猎人的状态,所以可以把 1 号猎人从状态中分离出来。

1 号猎人最后死亡,等价于 1 号猎人之后死亡的猎人集合 “恰好” 是空集。所以考虑容斥一下。

正常容斥会想到二项式反演,这里求的是集合,就用子集反演替代。

\(F(S)\) 为钦定集合 \(S\) 中的猎人在 1 号猎人之后死亡,其余猎人死亡顺序随意的概率,\(G(S)\) 为 1 号猎人之后死亡的猎人集合恰为 \(S\) 的概率。

有:\(F(S) = \sum\limits_{S \subseteq T} G(T)\).

根据子集反演有:\(G(S) = \sum\limits_{S \subseteq T} (-1)^{|T| - |S|} F(T)\).

答案求的是:\(G(\emptyset) = \sum\limits_{S} (-1)^{|S|} F(S)\).

现在的问题是求出 \(F(S)\).

这里又可以转化:原问题等价于开枪任意多次,每次都打中除 1 号猎人和 \(S\) 中猎人意以外的人,最终打中 1 号猎人的概率。

\(\operatorname{w}(S) = \sum\limits_{k \in S} w_k, N = \sum\limits_{i = 1}^n w_i\).

考虑枚举开枪次数,可以得到 \(F(S) = \sum\limits_{i = 0}^{+ \infty} (\frac{N - w_1 - \operatorname{w}(S)}{N})^i \frac{w_1}{N}\).

\(F(S)\) 应用等比数列求和公式得:\(\sum\limits_{i = 0}^{+ \infty} (\frac{N - w_1 - \operatorname{w}(S)}{N})^i = \frac{1}{1 - \frac{N - w_1 - \operatorname{w}(S)}{N}} = \frac{N}{w_1 + \operatorname{w}(S)}\).

于是 \(F(S) = \frac{w_1}{N} \cdot \frac{N}{w_1 + \operatorname{w}(S)} = \frac{w_1}{w_1 + \operatorname{w}(S)}\).

代入 \(G\)\(G(\emptyset) = \sum\limits_{S} (-1)^{|S|} \frac{w_1}{w_1 + \operatorname{w}(S)}\).

看起来需要枚举子集,复杂度是指数级别的。

但是注意到这个式子的取值主要和 \(\operatorname{w}(S)\) 有关,并且题目保证 \(\operatorname{w}(S) \leq 10^5\),所以可以枚举 \(\operatorname{w}(S)\) 的取值。

问题变成:钦定 \(\operatorname{w}(S)\) 的取值 \(k\),求 \(C(k) = \sum\limits_{\operatorname{w}(S) = k} (-1)^{|S|}\).

这样答案就是 \(\sum\limits_{i = 0}^{N} C(i) \frac{w_1}{w_1 + i}\).

也就是一个类似背包的无标号计数问题。

考虑令第 \(i\) 个猎人的 OGF 为 \(1 - w_i\),直接把所有猎人的 OGF 卷起来就可以得到答案。

因为 \(\operatorname{w}(S) \leq 10^5\),考虑分治 NTT 合并 \(n\) 个多项式的复杂度大约是 \(O(n \log^2 n)\).

代码

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;

const int maxn = 2.7e5 + 5;
const int mod = 998244353;
const int g = 3;

int n, a;
int rev[maxn], len[maxn], pos[maxn];
ll F[maxn], Ft[maxn], Gt[maxn], wp[maxn];

void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

ll qpow(ll base, ll power, ll mod)
{
    ll res = 1;
    while (power)
    {
        if (power & 1) res = res * base % mod;
        base = base * base % mod;
        power >>= 1;
    }
    return res;
}

void NTT(ll *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len, mod);
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n);
    reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2, mod);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}

void times(ll *F, ll *G, int lenf, int leng)
{
    int k = 1, lim = lenf + leng - 1;
    while (k < lim) k <<= 1;
    for (int i = 0; i < lenf; i++) Ft[i] = F[i], F[i] = 0;
    for (int i = 0; i < leng; i++) Gt[i] = G[i], G[i] = 0;
    NTT(Ft, k), NTT(Gt, k);
    for (int i = 0; i < k; i++) Ft[i] = Ft[i] * Gt[i] % mod;
    INTT(Ft, k);
    for (int i = 0; i < lim; i++) F[i] = Ft[i];
    for (int i = 0; i < k; i++) Ft[i] = Gt[i] = 0;
}

int main()
{
    scanf("%d%d", &n, &a), n--;
    if (n == 0) return puts("1"), 0;
    for (int i = 1; i <= n; i++)
    {
        scanf("%d", &len[i]), len[i]++;
        pos[i + 1] = pos[i] + len[i];
        F[pos[i]] = 1, F[pos[i + 1] - 1] = mod - 1;
    }
    int tot = pos[n + 1];
    while (n > 1)
    {
        for (int i = 1; i + 1 <= n; i += 2)
        {
            times(&F[pos[i]], &F[pos[i + 1]], len[i], len[i + 1]);
            len[(i + 1) >> 1] = len[i] + len[i + 1] - 1, pos[(i + 1) >> 1] = pos[i];
        }
        if (n & 1) len[(n + 1) >> 1] = len[n], pos[(n + 1) >> 1] = pos[n], n++;
        n >>= 1;
    }
    ll ans = 0;
    for (int i = 0; i <= tot; i++) ans = (ans + F[i] * a % mod * qpow(i + a, mod - 2, mod)) % mod;
    printf("%lld\n", ans);
    return 0;
}
posted @ 2023-02-20 14:35  kymru  阅读(39)  评论(0编辑  收藏  举报