luogu P5644 [PKUWC2018]猎人杀

https://www.luogu.com.cn/problem/P5644

好题啊好题!!
首先考虑鞭尸,即可以打死去的人,但是还要继续开枪,直到打到活人

t o t = ∑ w i , 则 p i = w i t o t \large tot=\sum w_i,则p_i=\frac{w_i}{tot} tot=wipi=totwi

F ( S ) 表 示 恰 好 集 合 S 中 的 人 在 1 之 后 被 打 死 的 概 率 F(S)表示恰好集合S中的人在1之后被打死的概率 F(S)S1
G ( S ) 表 示   钦 定   集 合 S 中 的 人 在 1 之 后 被 打 死 的 概 率 G(S)表示 \ 钦定 \ 集合S中的人在1之后被打死的概率 G(S)  S1
然后考虑子集反演可得
A N S = F ( ∅ ) = ∑ ( − 1 ) ∣ S ∣ G ( S ) ANS = F(\empty)=\sum(-1)^{|S|}G(S) ANS=F()=(1)SG(S)
问题变为如何计算G
由于可以鞭尸,所以变成无限开枪,每次可以打到集合外的人,直到打到1为止
G ( S ) = ∑ i = 0 ∞ ( t o t − w 1 − s u m ( S ) t o t ) i w 1 t o t \large G(S)=\sum\limits_{i=0}^{\infin}(\frac{tot-w_1-sum(S)}{tot})^i\frac{w_1}{tot} G(S)=i=0(tottotw1sum(S))itotw1
前面那个显然是收敛的,用等比数列求和可得
G ( S ) = w 1 w 1 + s u m ( S ) G(S)=\frac{w_1}{w_1+sum(S)} G(S)=w1+sum(S)w1
带入回去可得
A N S = ∑ ( − 1 ) ∣ S ∣ w 1 w 1 + s u m ( S ) ANS=\sum(-1)^{|S|}\frac{w_1}{w_1+sum(S)} ANS=(1)Sw1+sum(S)w1
枚举子集显然起飞,不太行
考虑枚举 s u m ( S ) sum(S) sum(S),把前面的容斥系数求个和即可
C ( i ) = ∑ s u m ( S ) = i ( − 1 ) ∣ S ∣ C(i)=\sum\limits_{sum(S)=i} (-1)^{|S|} C(i)=sum(S)=i(1)S
A N S = ∑ i = 0 t o t C ( i ) w 1 w 1 + i \large ANS=\sum\limits_{i=0}^{tot}C(i)\frac{w_1}{w_1+i} ANS=i=0totC(i)w1+iw1

C 显 然 就 是 = ∏ i = 2 n ( 1 − x w i ) \large C显然就是=\prod\limits_{i=2}^n (1-x^{w_i}) C=i=2n(1xwi)

分治FFT即可

code:


#include<bits/stdc++.h>
#define V vector<int>
#define N 800050
#define mod 998244353
using namespace std;
int add(int x, int y) { x += y;
    if(x >= mod) x -= mod;
    return x;
}
int sub(int x, int y) { x -= y;
    if(x < 0) x += mod;
    return x;
}
int mul(int x, int y) {
    return 1ll * x * y % mod;
}
int qpow(int x, int y) {
    int ret = 1;
    for(; y; y >>= 1, x = mul(x, x)) if(y & 1) ret = mul(ret, x);
    return ret;
}
const int G = 3;
const int G_inv = qpow(3, mod - 2);
int rev[N], w[N];
void ntt(V &a, int n, int o) {
    while(a.size() < n) a.push_back(0);
    for(int i = 1; i < n; i ++) if(i > rev[i]) swap(a[i], a[rev[i]]);
    for(int len = 2; len <= n; len <<= 1) {
        int w0 = qpow(o == 1? G : G_inv, (mod - 1) / len);
        for(int j = 0; j < n; j += len) {
            int wn = 1;
            for(int k = j; k < j + (len >> 1); k ++, wn = mul(wn, w0)) {
                int X = a[k], Y = mul(wn, a[k + (len >> 1)]);
                a[k] = add(X, Y), a[k + (len >> 1)] = sub(X, Y);
            }
        }
    }
    int ninv = qpow(n, mod - 2);
    if(o == -1)
        for(int i = 0; i < n; i ++) a[i] = mul(a[i], ninv);
}
void prt(V C) {
    printf("%d\n", (int)C.size());
    for(int i = 0; i < C.size(); i ++) printf("%d ", C[i]); printf("\n");
}
V cdq(int l, int r) {
    if(l == r) {
        V a; a.clear();
        a.push_back(1);
        for(int i = 1; i < w[l]; i ++) a.push_back(0);
        a.push_back(mod - 1);
        return a;
    }
    int mid = (l + r) >> 1;
    V a = cdq(l, mid), b = cdq(mid + 1, r);
//    prt(a), prt(b);
    int n = a.size() - 1, m = b.size() - 1, len = 1;
    for(; len <= n + m; ) len <<= 1;
    for(int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    ntt(a, len, 1), ntt(b, len, 1);
  //  prt(a), prt(b);
    for(int i = 0; i < len; i ++) a[i] = mul(a[i], b[i]);
    ntt(a, len, - 1);
    while(a.size() > n + m + 1) a.pop_back();
//    prt(a);
 //   printf("----%d %d--\n", l, r);
    return a;
}
int n;
int main() {
    scanf("%d", &n); int tot = 0;
    for(int i = 1; i <= n; i ++) scanf("%d", &w[i]), tot += w[i];
    tot -= w[1];
    V C = cdq(2, n); int ans = 0;
  //  prt(C);
    //printf("** %d\n", tot);
    for(int i = 0; i <= tot; i ++) ans = add(ans, mul(C[i], mul(w[1], qpow(add(w[1], i), mod - 2))));
    printf("%d", ans);
    return 0;
}

posted @ 2021-07-30 19:10  lahlah  阅读(49)  评论(0编辑  收藏  举报