「PKUWC2018」猎人杀(分治NTT+概率期望)
Description
猎人杀是一款风靡一时的游戏“狼人杀”的民间版本,他的规则是这样的:
一开始有 \(n\) 个猎人,第 \(i\) 个猎人有仇恨度 \(w_i\) ,每个猎人只有一个固定的技能:死亡后必须开一枪,且被射中的人也会死亡。
然而向谁开枪也是有讲究的,假设当前还活着的猎人有 \([i_1,i_2,...,i_m]\),那么有 \(\frac{w_{i_k}}{\sum_{j=1}^nw_{i_j}}\) 的概率是向猎人 \(k\) 开枪。
一开始第一枪由你打响,目标的选择方法和猎人一样(即有 \(\frac{w_i}{\sum_{j=1}^nw_j}\) 的概率射中第 \(i\) 个猎人)。由于开枪导致的连锁反应,所有猎人最终都会死亡,现在 \(1\) 号猎人想知道它是最后一个死的的概率。
答案对 \(998244353\) 取模。
【输入格式】
第一行一个正整数 \(n\) ;
第二行 \(n\) 个正整数,第 \(i\) 个正整数表示 \(w_i\)。
【输出格式】
输出一个非负整数表示答案。
【输入样例】
3
1 1 2
【输出样例】
915057324
【样例解释】
答案是 \(\frac{2}{4}×\frac{1}{2}+\frac{1}{4}×\frac{2}{3}=\frac{5}{12}\)。
【数据规模与约定】
对于 \(10\%\) 的数据,有 \(1\leq n\leq 10\)。
对于 \(30\%\) 的数据,有 \(1\leq n\leq 20\)。
对于 \(50\%\) 的数据,有 \(1\leq \sum\limits_{i=1}^{n}w_i\leq 5000\)。
另有 \(10\%\) 的数据,满足 \(1\leq w_i\leq 2\),且 \(w_1=1\)。
另有 \(10\%\) 的数据,满足 \(1\leq w_i\leq 2\),且 \(w_1=2\)。
对于 \(100\%\) 的数据,有 \(w_i>0\),且 \(1\leq \sum\limits_{i=1}^{n}w_i \leq 100000\)。
Solution
考虑容斥,即枚举强制在 \(1\) 号之后死的人。设 \(T\) 为枚举到的人的集合,\(S\) 为 \(T\) 中的 \(w_i\) 之和。
考虑怎么求 \(T\) 中的人都在 \(1\) 号之后死的概率。可以将它们合并成 \(0\) 号猎人,\(w_0=S\)。那么现在 \(\lceil\) \(0\) 号在 \(1\) 号之后死的概率 \(\rfloor\) 就是 \(\lceil\) \(T\) 中的人都在 \(1\) 号之后死的概率 \(\rfloor\)。显然 \(0\) 号和 \(1\) 号谁先死不受其它猎人影响,那么 \(\lceil\) \(0\) 号在 \(1\) 号之后死的概率 \(\rfloor\) 就是 \(\frac{w_1}{S+w_1}\),所以 \(\lceil\) \(T\) 中的人都在 \(1\) 号之后死的概率 \(\rfloor\) 也是 \(\frac{w_1}{S+w_1}\)。
集合 \(T\) 对答案的贡献为 \((-1)^{|T|}×\frac{w_1}{S+w_1}\)。
发现 \(\sum w_i \leq 10^5\),考虑对于每个 \(S\),求出 \(b_S\) 表示满足\(w_i\) 之和为 \(S\) 的集合 \(T\) 的 \((-1)^T\) 之和。 那么 \(ans=\sum b_S×\frac{w_1}{S+w_1}\)。
显然 \(b_S\) 就是多项式 \(\Pi _{i=2}^n(1-x^{w_i})\) 中 \(x^S\) 项的系数,分治 \(\text{NTT}\) 即可。
设 \(m=\sum_{i=1}^n w_i\),时间复杂度 \(O(m \log m)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch - 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int e = 2e5 + 5, mod = 998244353;
vector<int>g[e];
int rev[e], n, ans, val[e], lim;
inline int ksm(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1) res = (ll)res * x % mod;
y >>= 1;
x = (ll)x * x % mod;
}
return res;
}
inline void upt(int &x, int y)
{
x = y;
if (x >= mod) x -= mod;
}
inline void fft(int *a, int n, int op)
{
int i, j, k, r = (op == 1 ? 3 : (mod + 1) / 3);
for (i = 0; i < n; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (k = 1; k < n; k <<= 1)
{
int w0 = ksm(r, (mod - 1) / (k << 1));
for (i = 0; i < n; i += (k << 1))
{
int w = 1;
for (j = 0; j < k; j++)
{
int b = a[i + j], c = (ll)w * a[i + j + k] % mod;
upt(a[i + j], b + c);
upt(a[i + j + k], b + mod - c);
w = (ll)w * w0 % mod;
}
}
}
}
inline void modify(int *a, int *b, int la, int lb)
{
int i;
fft(a, lim, 1);
fft(b, lim, 1);
for (i = 0; i < lim; i++) a[i] = (ll)a[i] * b[i] % mod;
fft(a, lim, -1);
int tot = ksm(lim, mod - 2);
for (i = 0; i < la + lb - 1; i++) a[i] = (ll)a[i] * tot % mod;
}
inline void solve(int l, int r)
{
if (l >= r) return;
int mid = l + r >> 1;
solve(l, mid); solve(mid + 1, r);
int i, la = g[l].size(), lb = g[mid + 1].size();
static int a[e], b[e];
int k = 0; lim = 1;
while (lim < la + lb - 1) lim <<= 1, k++;
for (i = 0; i < lim; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << k - 1), a[i] = b[i] = 0;
for (i = 0; i < la; i++) a[i] = g[l][i];
for (i = 0; i < lb; i++) b[i] = g[mid + 1][i];
modify(a, b, la, lb);
g[l].resize(la + lb - 1);
for (i = 0; i < la + lb - 1; i++) g[l][i] = a[i];
}
int main()
{
int i, sum = 0;
read(n);
for (i = 1; i <= n; i++) read(val[i]), sum += val[i];
sum -= val[1];
g[1].push_back(1);
for (i = 2; i <= n; i++)
{
g[i].resize(val[i] + 1);
g[i][0] = 1;
g[i][val[i]] = mod - 1;
}
solve(1, n);
for (i = val[1]; i <= sum + val[1]; i++)
{
int x = i - val[1], inv = ksm(i, mod - 2);
ans = (ans + (ll)g[1][x] * inv) % mod;
}
ans = (ll)ans * val[1] % mod;
cout << ans << endl;
fclose(stdin);
fclose(stdout);
return 0;
}