[LOJ2541] [PKUWC2018] 猎人杀
题目链接
LOJ:https://loj.ac/problem/2541
Solution
很巧妙的思路。
注意到运行的过程中概率的分母在不停的变化,这样会让我们很不好算,我们考虑这样转化:假设所有人都活着,然后随机选一个人,如果此人已死那就重新选一次。
假设当前活着的人集合为\(T\),那么射中第\(i\)个人的概率就是:
\[\sum_{i=0}^{\infty}\left(\frac{s_{all}-s_T}{s_{all}}\right)^i\frac{w_i}{s_{all}}=\frac{w_i}{s_T}
\]
其中\(s_p\)表示\(p\)集合的\(w\)总和,可以发现这样选的概率和原来是一样的。
我们考虑容斥,设\(f(T)\)表示至少\(T\)集合的人比\(1\)号后死,用一个很简单的容斥可以得到:
\[ans=\sum_{T}(-1)^{|T|}f(T)
\]
那么大力算可以得到\(f\):
\[\begin{align}f(T)&=\sum_{i=0}^{\infty}\left(\frac{s_{all}-s_T-w_1}{s_{all}}\right)^i\cdot \frac{w_1}{s_{all}}\\&=\frac{w_1}{w_1+s_T}\end{align}
\]
答案就是:
\[ans=\sum_T(-1)^{|T|}\frac{w_1}{w_1+s_T}
\]
注意到\(s\)至多只有\(1e5\),我们可以背包算出每个\(s_T\)出现了多少次,背包的时候顺便把容斥系数带上。
这样做是\(O(ns)\)的,显然\(T\)掉了。
但是我们可以用生成函数优化这个东西,直接就是:
\[\prod_{i=2}^{n}(1-x^{w_i})
\]
然后分治\(FFT\)优化就好了,复杂度\(O(n\log ^2 n)\)。
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double
#define ll long long
#define pii pair<int,int >
#define vec vector<int >
#define pb push_back
#define mp make_pair
#define fr first
#define sc second
#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)
const int maxn = 4e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;
int w[maxn],pos[maxn],N,bit,f[maxn],a[maxn],s[maxn],n,mxn;
int add(int x,int y) {return x+y>=mod?x+y-mod:x+y;}
int del(int x,int y) {return x-y<0?x-y+mod:x-y;}
int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}
int qpow(int a,int x) {
int res=1;
for(;x;x>>=1,a=mul(a,a)) if(x&1) res=mul(res,a);
return res;
}
void prepare(int t) {
for(N=1,bit=0;N<=t;N<<=1,bit++);mxn=N;w[0]=1,w[1]=qpow(3,(mod-1)/mxn);
for(int i=2;i<=N;i++) w[i]=mul(w[i-1],w[1]);
}
void ntt_get(int t) {
for(N=1,bit=0;N<=t;N<<=1,bit++);
for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
}
void ntt(int *r,int op) {
for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++) {
int x=r[j+k],y=mul(r[i+j+k],w[k*d]);
r[j+k]=add(x,y),r[i+j+k]=del(x,y);
}
if(op==-1) {
reverse(r+1,r+N);int d=qpow(N,mod-2);
for(int i=0;i<N;i++) r[i]=mul(r[i],d);
}
}
int get(int lt,int rt) {
int l=lt,r=rt,mid,ans=lt;
while(l<=r) {
mid=(l+r)>>1;
if(s[rt]-s[mid]>=s[mid]-s[lt-1]) l=mid+1,ans=mid;
else r=mid-1;
}return ans;
}
void solve(int l,int r,int *t) {
if(l>r) return ;
if(l==r) {t[0]=1,t[a[l]]=mod-1;return ;}
int d=1<<((int)ceil(log2(s[r]-s[l-1]))+1);
int *sl=new int [d+10],*sr=new int [d+10],mid=get(l,r);
for(int i=0;i<=d+5;i++) sl[i]=sr[i]=0;
solve(l,mid,sl),solve(mid+1,r,sr);
ntt_get(d>>1);ntt(sl,1),ntt(sr,1);
for(int i=0;i<N;i++) t[i]=mul(sl[i],sr[i]);
ntt(t,-1);delete sl;delete sr;
}
int main() {
read(n);for(int i=1;i<=n;i++) read(a[i]),s[i]=s[i-1]+a[i];
prepare(s[n]<<1);solve(2,n,f);int ans=0;
for(int i=0;i<=s[n];i++) ans=add(ans,mul(qpow(a[1]+i,mod-2),f[i]));
write(mul(ans,a[1]));
return 0;
}