[ABC331G] Collect Them All 题解
\(\textbf{Statement.}\)
有 \(M\) 种颜色,用 \(1\sim M\) 编号,每次抽奖抽中第 \(i\) 种颜色的概率为 \(\frac{c_i}{N}\),其中 \(\sum c_i=N\),求抽中每种颜色至少一次的期望次数。
\(1\le M\le N\le 2\times 10^5\)。
\(\textbf{Solution.}\)
发现直接求不好做,考虑容斥。记 \(f(n)\) 表示在前 \(n\) 天抽中每种颜色至少一次的概率,\(g(S)=\frac{\sum_{i\in S}c_i}{N}\)。枚举第 \(n\) 天时抽过的颜色集合作容斥,易得:
\[f(n)=\sum_{S\subset \{1,2,\dots,M\}}(-1)^{M-|S|}g(S)^n
\]
那么在第 \(n+1\) 天还没收集完的概率就是 \(1-f(n)\),所以答案就转化为(注意,\(\subset\) 是真包含于,不能相等,\(\subseteq\) 是包含于,可以相等):
\[\begin{aligned}
&\sum_{n=0}^{\infty}1-f(n)\\
&=\sum_{n=0}^{\infty}\left(1-\sum_{S\subset \{1,2,\dots,M\}}(-1)^{M-|S|}g(S)^n\right)\\
&=\sum_{n=0}^{\infty}\sum_{S\subseteq \{1,2,\dots,M\}}(-1)^{M-|S|+1}g(S)^n\\
&=\sum_{S\subseteq \{1,2,\dots,M\}}(-1)^{M-|S|+1}\sum_{n=0}^{\infty}g(S)^n\\
&=\sum_{S\subseteq \{1,2,\dots,M\}}(-1)^{M-|S|+1}\frac{1}{1-g(S)}\\
&=\sum_{k=0}^{N-1}\frac{N}{N-k}\sum_{g(S)=\frac k N}(-1)^{M-|S|+1}
\end{aligned}
\]
然后就可以用生成函数做了,跟 P9773 相同的做法,答案就是:
\[\sum_{k=0}^{N-1}[x^k](-1)^{M+1}\prod_{i=1}^{M}(1-x^{c_i})
\]
考虑怎么求出 \(\prod_{i=1}^{M}(1-x^{c_i})\),可以先将这些次数为 \(c_i\) 的多项式加入一个堆,每次弹出两个次数最小的,用 NTT 暴力合并,时间复杂度 \(O((N+M)\log M\log N)\)。
参考代码:
#include<bits/stdc++.h>
#define ll long long
#define int long long
#define mxn 500003
#define md 998244353
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define rept(i,a,b) for(int i=a;i<b;++i)
#define drep(i,a,b) for(int i=a;i>=b;--i)
using namespace std;
int power(int x,int y){
int ans=1;
for(;y;y>>=1){
if(y&1)ans=(ll)ans*x%md;
x=(ll)x*x%md;
}
return ans;
}
struct node{
vector<int>a;
};
inline bool operator<(node x,node y){
return x.a.size()>y.a.size();
}
int n,m,a[mxn],f[mxn],f1[mxn],rev[mxn];
ll ans;
priority_queue<node>q;
void ntt(int *a,int n,int flag){
rept(i,0,n)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int h=1;h<n;h<<=1){
int x,y,s=power(3,499122176/h);
for(int j=0;j<n;j+=h<<1){
int w=1;
for(int k=j;k<j+h;++k){
x=a[k],y=w*a[k+h]%md;
a[k]=(x+y)%md;
a[k+h]=(x-y+md)%md;
w=w*s%md;
}
}
}
if(flag==-1){
int p=power(n,md-2);
reverse(a+1,a+n);
rept(i,0,n)a[i]=a[i]*p%md;
}
}
signed main(){
scanf("%lld%lld",&n,&m);
rep(i,1,m){
scanf("%lld",&a[i]);
node s;
s.a.resize(a[i]+1);
s.a[0]=1,s.a[a[i]]=-1;
q.push(s);
}
while(q.size()>1){
node a=q.top();q.pop();
node b=q.top();q.pop();
node c;c.a.resize(a.a.size()+b.a.size()-1);
int k,s;
for(k=0,s=1;s<c.a.size();s<<=1,++k);
rept(i,0,s)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
rept(i,0,a.a.size())f[i]=a.a[i];
rept(i,a.a.size(),s)f[i]=0;
rept(i,0,b.a.size())f1[i]=b.a[i];
rept(i,b.a.size(),s)f1[i]=0;
ntt(f,s,1);ntt(f1,s,1);
rept(i,0,s)f[i]=f[i]*f1[i]%md;
ntt(f,s,-1);
rept(i,0,c.a.size())c.a[i]=f[i];
q.push(c);
}
node s=q.top();
rept(i,0,n){
ans=(ans+(m&1?1:-1)*s.a[i]%md*n%md*power(n-i,md-2))%md;
}
cout<<(ans+md)%md;
return 0;
}