luogu P5644 [PKUWC2018]猎人杀
https://www.luogu.com.cn/problem/P5644
好题啊好题!!
首先考虑鞭尸,即可以打死去的人,但是还要继续开枪,直到打到活人
设 t o t = ∑ w i , 则 p i = w i t o t \large tot=\sum w_i,则p_i=\frac{w_i}{tot} tot=∑wi,则pi=totwi
设
F
(
S
)
表
示
恰
好
集
合
S
中
的
人
在
1
之
后
被
打
死
的
概
率
F(S)表示恰好集合S中的人在1之后被打死的概率
F(S)表示恰好集合S中的人在1之后被打死的概率
设
G
(
S
)
表
示
钦
定
集
合
S
中
的
人
在
1
之
后
被
打
死
的
概
率
G(S)表示 \ 钦定 \ 集合S中的人在1之后被打死的概率
G(S)表示 钦定 集合S中的人在1之后被打死的概率
然后考虑子集反演可得
A
N
S
=
F
(
∅
)
=
∑
(
−
1
)
∣
S
∣
G
(
S
)
ANS = F(\empty)=\sum(-1)^{|S|}G(S)
ANS=F(∅)=∑(−1)∣S∣G(S)
问题变为如何计算G
由于可以鞭尸,所以变成无限开枪,每次可以打到集合外的人,直到打到1为止
G
(
S
)
=
∑
i
=
0
∞
(
t
o
t
−
w
1
−
s
u
m
(
S
)
t
o
t
)
i
w
1
t
o
t
\large G(S)=\sum\limits_{i=0}^{\infin}(\frac{tot-w_1-sum(S)}{tot})^i\frac{w_1}{tot}
G(S)=i=0∑∞(tottot−w1−sum(S))itotw1
前面那个显然是收敛的,用等比数列求和可得
G
(
S
)
=
w
1
w
1
+
s
u
m
(
S
)
G(S)=\frac{w_1}{w_1+sum(S)}
G(S)=w1+sum(S)w1
带入回去可得
A
N
S
=
∑
(
−
1
)
∣
S
∣
w
1
w
1
+
s
u
m
(
S
)
ANS=\sum(-1)^{|S|}\frac{w_1}{w_1+sum(S)}
ANS=∑(−1)∣S∣w1+sum(S)w1
枚举子集显然起飞,不太行
考虑枚举
s
u
m
(
S
)
sum(S)
sum(S),把前面的容斥系数求个和即可
设
C
(
i
)
=
∑
s
u
m
(
S
)
=
i
(
−
1
)
∣
S
∣
C(i)=\sum\limits_{sum(S)=i} (-1)^{|S|}
C(i)=sum(S)=i∑(−1)∣S∣
A
N
S
=
∑
i
=
0
t
o
t
C
(
i
)
w
1
w
1
+
i
\large ANS=\sum\limits_{i=0}^{tot}C(i)\frac{w_1}{w_1+i}
ANS=i=0∑totC(i)w1+iw1
C 显 然 就 是 = ∏ i = 2 n ( 1 − x w i ) \large C显然就是=\prod\limits_{i=2}^n (1-x^{w_i}) C显然就是=i=2∏n(1−xwi)
分治FFT即可
code:
#include<bits/stdc++.h>
#define V vector<int>
#define N 800050
#define mod 998244353
using namespace std;
int add(int x, int y) { x += y;
if(x >= mod) x -= mod;
return x;
}
int sub(int x, int y) { x -= y;
if(x < 0) x += mod;
return x;
}
int mul(int x, int y) {
return 1ll * x * y % mod;
}
int qpow(int x, int y) {
int ret = 1;
for(; y; y >>= 1, x = mul(x, x)) if(y & 1) ret = mul(ret, x);
return ret;
}
const int G = 3;
const int G_inv = qpow(3, mod - 2);
int rev[N], w[N];
void ntt(V &a, int n, int o) {
while(a.size() < n) a.push_back(0);
for(int i = 1; i < n; i ++) if(i > rev[i]) swap(a[i], a[rev[i]]);
for(int len = 2; len <= n; len <<= 1) {
int w0 = qpow(o == 1? G : G_inv, (mod - 1) / len);
for(int j = 0; j < n; j += len) {
int wn = 1;
for(int k = j; k < j + (len >> 1); k ++, wn = mul(wn, w0)) {
int X = a[k], Y = mul(wn, a[k + (len >> 1)]);
a[k] = add(X, Y), a[k + (len >> 1)] = sub(X, Y);
}
}
}
int ninv = qpow(n, mod - 2);
if(o == -1)
for(int i = 0; i < n; i ++) a[i] = mul(a[i], ninv);
}
void prt(V C) {
printf("%d\n", (int)C.size());
for(int i = 0; i < C.size(); i ++) printf("%d ", C[i]); printf("\n");
}
V cdq(int l, int r) {
if(l == r) {
V a; a.clear();
a.push_back(1);
for(int i = 1; i < w[l]; i ++) a.push_back(0);
a.push_back(mod - 1);
return a;
}
int mid = (l + r) >> 1;
V a = cdq(l, mid), b = cdq(mid + 1, r);
// prt(a), prt(b);
int n = a.size() - 1, m = b.size() - 1, len = 1;
for(; len <= n + m; ) len <<= 1;
for(int i = 1; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (len >> 1));
ntt(a, len, 1), ntt(b, len, 1);
// prt(a), prt(b);
for(int i = 0; i < len; i ++) a[i] = mul(a[i], b[i]);
ntt(a, len, - 1);
while(a.size() > n + m + 1) a.pop_back();
// prt(a);
// printf("----%d %d--\n", l, r);
return a;
}
int n;
int main() {
scanf("%d", &n); int tot = 0;
for(int i = 1; i <= n; i ++) scanf("%d", &w[i]), tot += w[i];
tot -= w[1];
V C = cdq(2, n); int ans = 0;
// prt(C);
//printf("** %d\n", tot);
for(int i = 0; i <= tot; i ++) ans = add(ans, mul(C[i], mul(w[1], qpow(add(w[1], i), mod - 2))));
printf("%d", ans);
return 0;
}