[PKUWC2018] 猎人杀
题目描述
猎人杀是一款风靡一时的游戏“狼人杀”的民间版本,他的规则是这样的:
一开始有 \(n\) 个猎人,第 \(i\) 个猎人有仇恨度 \(w_i\) ,每个猎人只有一个固定的技能:死亡后必须开一枪,且被射中的人也会死亡。
然而向谁开枪也是有讲究的,假设当前还活着的猎人有 \([i_1\ldots i_m]\),那么有 \(\frac{w_{i_k}}{\sum_{j = 1}^{m} w_{i_j}}\) 的概率是向猎人 \(i_k\) 开枪。
一开始第一枪由你打响,目标的选择方法和猎人一样(即有 \(\frac{w_i}{\sum_{j=1}^{n}w_j}\) 的概率射中第 \(i\) 个猎人)。由于开枪导致的连锁反应,所有猎人最终都会死亡,现在 \(1\) 号猎人想知道它是最后一个死的的概率。
答案对 \(998244353\) 取模。
对于 \(100\%\) 的数据,有 \(w_i>0\),且 \(1\leq \sum\limits_{i=1}^{n}w_i \leq 100000\)
题解
非常妙的第一步。
现在这个分母会改变,很不好算。所以把题目进行一个转换:假设死了的猎人尸体仍然在那里,然后在打枪的时候如果达到一个尸体就再打一次。那么这样对于任何一个猎人,他的死亡概率仍然是一样的。做了这个转换后,一个人被打到的概率就很固定了。
然后直接算不好算,考虑容斥。定义 \(f(S)\) 为在 1 前面被打到的集合恰好为 \(S\) 的概率,\(g(S)\) 为在 1 前面死的集合包含于 \(S\) 的概率。然后 \(g(S)\) 是可算的。枚举在第 \(i+1\) 次打到 1,那么在之前的打枪过程中一定是打到 \(S\) 中的猎人,那么概率 为 \((\frac{\sum\limits_{i\in S}w_i}{\sum\limits_{i=1}^nw_i})^i\),枚举的次数可以到无限,所以等比数列求和。最后乘上一个选择 1 的概率,也就是 \(\frac {w_1}{\sum\limits_{i=1}^nw_i}\)
计算 \(f(S)=\sum\limits_{T\in S}g(T)(-1)^{|T|-|S|}\)即可。
发现式子中之和集合中 \(w\) 的和有关,所以可以用背包跑出所有 \(w\) 的和,但是这样会超时,将背包改成分治 FFT 就好了
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+5,P=998244353;
int n,w[N],ret,ans,rr[N];
vector<int>g[N];
int pown(int x,int y)
{
if(!y)
return 1;
int t=pown(x,y>>1);
if(y&1)
return 1LL*t*t%P*x%P;
return 1LL*t*t%P;
}
int read()
{
int s=0;
char ch=getchar();
while(ch<'0'||ch>'9')
ch=getchar();
while(ch>='0'&&ch<='9')
s=s*10+ch-48,ch=getchar();
return s;
}
void ntt(vector<int>&a,int op)
{
for(int i=0;i<a.size();i++)
if(rr[i]<i)
swap(a[i],a[rr[i]]);
for(int md=1;md<a.size();md<<=1)
{
int g=pown(op? 3:332748118,(P-1)/(md<<1));
for(int i=0;i<a.size();i+=md<<1)
{
int pw=1;
for(int j=0;j<md;j++,pw=1LL*pw*g%P)
{
int k=a[i+j+md]*1LL*pw%P;
a[i+j+md]=(a[i+j]-k+P)%P;
(a[i+j]+=k)%=P;
}
}
}
if(!op)
{
int pw=pown(a.size(),P-2);
for(int i=0;i<a.size();i++)
a[i]=1LL*a[i]*pw%P;
}
}
vector<int> merge(int l,int r)
{
if(l==r)
return g[l];
int md=l+r>>1;
vector<int>a=merge(l,md),b=merge(md+1,r);
int s=1<<(int)log2(a.size()+b.size()-2)+1;
for(int i=1;i<s;i++)
rr[i]=rr[i>>1]>>1|(i&1)*s/2;
a.resize(s);
b.resize(s);
ntt(a,1);
ntt(b,1);
for(int i=0;i<s;i++)
a[i]=1LL*a[i]*b[i]%P;
ntt(a,0);
return a;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
w[i]=read();
g[i].push_back(1);
for(int j=1;j<w[i];j++)
g[i].push_back(0);
g[i].push_back(P-1);
ret+=w[i];
}
vector<int>a=merge(2,n);
for(int i=1;i<ret;i++)
(ans+=1LL*i*pown(ret-i,P-2)%P*a[i]%P)%=P;
if(n%2==0)
ans=1LL*ans*(P-1)%P;
printf("%lld\n",ans*1LL*w[1]%P*pown(ret,P-2)%P);
}