P5644(容斥)
看不懂其它题解的智慧证明/kk,于是自己口胡了一个。
考虑钦定一些人在 \(i\) 后面,然后容斥。现在问题转变成了求对于 \(S\),\(\{1\}\cup S\) 中 \(1\) 是第一个的方案数。
考虑 \(S=\{2,3,\cdots,n\}\),发现概率是 \(w_1/\sum_{i=1}^n w_i\),很自然地猜想对于集合 \(S\),概率是:
\[\frac{w_1}{w_1+\sum_{x\in S}w_x}
\]
这很符合我们的直觉。
假设 \(P(i,S)\) 表示 \(i\) 在集合 \(S\) 中第一个出现的概率,那么证明这个式子几乎等价于证明:
\[\forall i,j,\frac{P(i,S)}{P(j,S)}=\frac{w_i}{w_j}
\]
考虑 \(P(i,S)\) 的式子长什么样。枚举 \(i\) 前面的排列 \(T\),然后钦定 \(i\) 在 \(T\) 后面恰好一个寄掉,得到:
\[\sum_{T\cap S=\varnothing}f(T)\frac{w_i}{\sum_{j=1}^n w_j-\sum_{x\in T}w_x-w_i}
\]
\(f(T)\) 是前面的贡献,此处省去。对比 \(P(i,S)\) 和 \(P(j,S)\) 的式子不难证明。
然后实际上只需要求:
\[1+\sum_S (-1)^{|S|}\frac{w_1}{w_1+\sum_{x\in S}w_x}
\]
考虑到式子只和 \(\sum_{x\in S}w_x\) 和 \(|S|\) 有关,那么只需要做个背包就好了,分治+NTT 可以 \(O(n\log^2 n)\)。
两边分治的时候可以开 \(32\) 个数组,数组里面存多项式。有点像滚动数组的操作。
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define db double
#define ldb long double
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
using namespace std;
inline int read() {
int x = 0; bool op = 0;
char c = getchar();
while(!isdigit(c))op |= (c == '-'), c = getchar();
while(isdigit(c))x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return op ? -x : x;
}
pii aprx(int p, int q, int A) {
int x = q, y = p, a = 1, b = 0;
while(x > A) {
swap(x, y); swap(a, b);
a -= (x / y) * b; x %= y;
}
return mp(x, a);
}
const int N = (1 << 20) + 10;
const int P = 998244353;
void add(int &a, int b) {a = (a + b) % P;}
void sub(int &a, int b) {a = (a - b + P) % P;}
int ksm(int x, int k) {
int res = 1;
for(int pw = x; k; (k & 1) ? res = 1ll * res * pw % P : 0, pw = 1ll * pw * pw % P, k >>= 1);
return res;
}
int n;
int w[N];
int fac[N], ifac[N], pw[N], ipw[N];
void init() {
fac[0] = ifac[0] = 1;
for(int i = 1; i < N; i++)fac[i] = 1ll * fac[i - 1] * i % P;
ifac[N - 1] = ksm(fac[N - 1], P - 2);
for(int i = N - 2; i; i--)ifac[i] = 1ll * ifac[i + 1] * (i + 1) % P;
for(int i = 1; i <= 20; i++)pw[1 << i] = ksm(3, (P - 1) / (1 << i));
for(int i = 1; i <= 20; i++)ipw[1 << i] = ksm(pw[1 << i], P - 2);
return ;
}
int rev[N];
int inv(int x) {return 1ll * ifac[x] * fac[x - 1] % P;}
void NTT(int *f, int len, int op) {
for(int i = 0; i < len; i++)rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (len >> 1));
for(int i = 0; i < len; i++)if(i < rev[i])swap(f[i], f[rev[i]]);
for(int k = 2; k <= len; k <<= 1) {
for(int i = 0; i < len; i += k) {
int gn = (op == -1) ? ipw[k] : pw[k];
// printf("gn:%d %d\n", k, gn);
for(int j = i, g = 1; j < i + k / 2; j++, g = 1ll * g * gn % P) {
int x = f[j], y = 1ll * g * f[j + k / 2] % P;
f[j] = (x + y) % P; f[j + k / 2] = (x - y + P) % P;
}
}
}
if(op == -1) {
for(int i = 0; i < len; i++)f[i] = 1ll * f[i] * inv(len) % P;
}
return ;
}
int f[N], g[35][N], stk[N], top = 33;
int solve(int cur, int l, int r) {
if(l == r)return g[cur][w[l]] = P - 1, g[cur][0] = 1, w[l];
int mid = l + r >> 1, len = 0, ls = stk[top--], rs = stk[top--];
len += solve(ls, l, mid); len += solve(rs, mid + 1, r);
int ln = 1; while(ln < len + 5)ln <<= 1;
NTT(g[ls], ln, 1); NTT(g[rs], ln, 1);
for(int i = 0; i < ln; i++)g[cur][i] = 1ll * g[ls][i] * g[rs][i] % P;
NTT(g[cur], ln, -1);
for(int i = 0; i < ln; i++)g[ls][i] = g[rs][i] = 0;
stk[++top] = ls; stk[++top] = rs;
return len;
}
int ans[N];
int main() {
init();
n = read();
for(int i = 1; i <= n; i++)w[i] = read();
for(int i = 1; i <= 33; i++)stk[i] = i;
solve(0, 2, n);
int res = 1;
for(int i = 1; i <= 100000; i++) {
add(res, 1ll * g[0][i] * w[1] % P * inv(i + w[1]) % P);
}
printf("%d\n", res);
return 0;
}