P5644(容斥)

看不懂其它题解的智慧证明/kk,于是自己口胡了一个。

考虑钦定一些人在 \(i\) 后面,然后容斥。现在问题转变成了求对于 \(S\)\(\{1\}\cup S\)\(1\) 是第一个的方案数。

考虑 \(S=\{2,3,\cdots,n\}\),发现概率是 \(w_1/\sum_{i=1}^n w_i\),很自然地猜想对于集合 \(S\),概率是:

\[\frac{w_1}{w_1+\sum_{x\in S}w_x} \]

这很符合我们的直觉。

假设 \(P(i,S)\) 表示 \(i\) 在集合 \(S\) 中第一个出现的概率,那么证明这个式子几乎等价于证明:

\[\forall i,j,\frac{P(i,S)}{P(j,S)}=\frac{w_i}{w_j} \]

考虑 \(P(i,S)\) 的式子长什么样。枚举 \(i\) 前面的排列 \(T\),然后钦定 \(i\)\(T\) 后面恰好一个寄掉,得到:

\[\sum_{T\cap S=\varnothing}f(T)\frac{w_i}{\sum_{j=1}^n w_j-\sum_{x\in T}w_x-w_i} \]

\(f(T)\) 是前面的贡献,此处省去。对比 \(P(i,S)\)\(P(j,S)\) 的式子不难证明。

然后实际上只需要求:

\[1+\sum_S (-1)^{|S|}\frac{w_1}{w_1+\sum_{x\in S}w_x} \]

考虑到式子只和 \(\sum_{x\in S}w_x\)\(|S|\) 有关,那么只需要做个背包就好了,分治+NTT 可以 \(O(n\log^2 n)\)

两边分治的时候可以开 \(32\) 个数组,数组里面存多项式。有点像滚动数组的操作。

#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define db double
#define ldb long double
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
using namespace std;
inline int read() {
    int x = 0; bool op = 0;
    char c = getchar();
    while(!isdigit(c))op |= (c == '-'), c = getchar();
    while(isdigit(c))x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return op ? -x : x;
}
pii aprx(int p, int q, int A) {
    int x = q, y = p, a = 1, b = 0;
    while(x > A) {
        swap(x, y); swap(a, b);
        a -= (x / y) * b; x %= y;
    }
    return mp(x, a);
}
const int N = (1 << 20) + 10;
const int P = 998244353;
void add(int &a, int b) {a = (a + b) % P;}
void sub(int &a, int b) {a = (a - b + P) % P;}
int ksm(int x, int k) {
    int res = 1;
    for(int pw = x; k; (k & 1) ? res = 1ll * res * pw % P : 0, pw = 1ll * pw * pw % P, k >>= 1);
    return res;
}
int n;
int w[N];
int fac[N], ifac[N], pw[N], ipw[N];
void init() {
    fac[0] = ifac[0] = 1;
    for(int i = 1; i < N; i++)fac[i] = 1ll * fac[i - 1] * i % P;
    ifac[N - 1] = ksm(fac[N - 1], P - 2);
    for(int i = N - 2; i; i--)ifac[i] = 1ll * ifac[i + 1] * (i + 1) % P;
    for(int i = 1; i <= 20; i++)pw[1 << i] = ksm(3, (P - 1) / (1 << i));
    for(int i = 1; i <= 20; i++)ipw[1 << i] = ksm(pw[1 << i], P - 2);
    return ;
}
int rev[N];
int inv(int x) {return 1ll * ifac[x] * fac[x - 1] % P;}
void NTT(int *f, int len, int op) {
    for(int i = 0; i < len; i++)rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    for(int i = 0; i < len; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
    for(int k = 2; k <= len; k <<= 1) {
        for(int i = 0; i < len; i += k) {
            int gn = (op == -1) ? ipw[k] : pw[k];
            // printf("gn:%d %d\n", k, gn);
            for(int j = i, g = 1; j < i + k / 2; j++, g = 1ll * g * gn % P) {
                int x = f[j], y = 1ll * g * f[j + k / 2] % P;
                f[j] = (x + y) % P; f[j + k / 2] = (x - y + P) % P;
            }
        }
    }
    if(op == -1) {
        for(int i = 0; i < len; i++)f[i] = 1ll * f[i] * inv(len) % P;
    }
    return ;
}
int f[N], g[35][N], stk[N], top = 33;
int solve(int cur, int l, int r) {
    if(l == r)return g[cur][w[l]] = P - 1, g[cur][0] = 1, w[l];
    int mid = l + r >> 1, len = 0, ls = stk[top--], rs = stk[top--];
    len += solve(ls, l, mid); len += solve(rs, mid + 1, r);
    int ln = 1; while(ln < len + 5)ln <<= 1;
    NTT(g[ls], ln, 1); NTT(g[rs], ln, 1);
    for(int i = 0; i < ln; i++)g[cur][i] = 1ll * g[ls][i] * g[rs][i] % P;
    NTT(g[cur], ln, -1);
    for(int i = 0; i < ln; i++)g[ls][i] = g[rs][i] = 0;
    stk[++top] = ls; stk[++top] = rs;
    return len;
}
int ans[N];
int main() {
    init();
    n = read();
    for(int i = 1; i <= n; i++)w[i] = read();
    for(int i = 1; i <= 33; i++)stk[i] = i;
    solve(0, 2, n);
    int res = 1;
    for(int i = 1; i <= 100000; i++) {
        add(res, 1ll * g[0][i] * w[1] % P * inv(i + w[1]) % P);
    }
    printf("%d\n", res);     
    return 0;
}
posted @ 2022-07-11 20:39  yllcm  阅读(58)  评论(0编辑  收藏  举报