loj#2504. 「2018 集训队互测 Day 5」小 H 爱染色 解题报告
题意
给定 \(n\) 个球,初始为白色,进行两次操作,每次选择 \(m\) 个球染黑,给定一个 \(m\) 次多项式 \(F\)(仅给出前 \(m+1\) 项点值),设编号最小的黑球为 \(A\)(\(1\leqslant A\leqslant n\)),求 \(F(A-1)\) 的期望,对 \(998244353\) 取模。
\(1\leqslant n<9988244353,1\leqslant m\leqslant 10^6\)
分析
根据期望的线性性,我们把多项式拆成若干个单项式的和,那么我们设目前的次数为 \(c\),我们对于 \((A-1)^c\) 构造出一个组合意义,找到一个最长的白球前缀,在其中任选 \(c\) 个球的方案数,列出式子:(\(f_i\) 表示多项式第 \(i\) 项的系数,\(i\) 枚举在后面那一步实际上选到的白球数量,\(j\) 枚举实际上被染黑的球的数量)
\[ans_c=f_c\sum_{i=0}^m\begin{Bmatrix}c\\i\end{Bmatrix}i!\sum_{j=m}^{2m}{j\choose m}{m\choose m-(j-m)}{n\choose i+j}
\]
推一推式子:
\[\sum_{c=0}^mf_c\sum_{i=0}^m\begin{Bmatrix}c\\i\end{Bmatrix}i!\sum_{j=m}^{2m}{j\choose m}{m\choose m-(j-m)}{n\choose i+j}\\=\sum_{i=0}^mi!\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}{n\choose i+j}\sum_{c=0}^mf_c\begin{Bmatrix}c\\i\end{Bmatrix}
\]
我们用通项公式展开斯特林数:
\[\sum_{i=0}^mi!\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}{n\choose i+j}\sum_{c=0}^mf_c\frac{1}{i!}\sum_{k=0}^i{i\choose k}(-1)^{i-k}k^c\\=\sum_{i=0}^m\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}{n\choose i+j}\sum_{k=0}^i{i\choose k}(-1)^{i-k}\sum_{c=0}^m f_ck^c\\=\sum_{T=m}^{3m}{n\choose T}\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}\sum_{k=0}^{T-j}{T-j\choose k}(-1)^{T-j-k}F(k)
\]
其中 \(F(k)\) 表示 \(x=k\) 的时候的点值,容易发现后面枚举 \(k\) 的部分可以使用一次卷积算出来,前面枚举 \(j\) 的部分也可以使用一次卷积算出来,于是复杂度 \(O(n\log n)\)。
代码
#include<stdio.h>
#include<vector>
using namespace std;
const int maxn=1<<22,mod=998244353,G=3,invG=(mod+1)/G;
typedef vector<int>poly;
int n,m,ans,lim;
int p[maxn],inv[maxn],fac[maxn],nfac[maxn],F[maxn];
poly f,g;
inline int read(){
int x=0;
char c=getchar();
for(;c<'0'||c>'9';c=getchar());
for(;c>='0'&&c<='9';c=getchar())
x=x*10+c-48;
return x;
}
inline int C(int a,int b){
return a<b? 0:1ll*fac[a]*nfac[b]%mod*nfac[a-b]%mod;
}
int ksm(int a,int b,int mod){
int res=1;
while(b){
if(b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod,b>>=1;
}
return res;
}
int getlen(int n){
int lim=1,r=0;
for(;lim<n;lim<<=1,r++);
for(int i=0;i<lim;i++)
p[i]=(p[i>>1]>>1)|((i&1)<<(r-1));
return lim;
}
void NTT(poly &x,int opt){
x.resize(lim);
for(int i=0;i<lim;i++)
if(i<p[i])
swap(x[i],x[p[i]]);
for(int len=2,now=1,p=1;len<=lim;len<<=1,now<<=1,p++){
int w=ksm(opt==1? G:invG,(mod-1)/len,mod);
for(int i=0;i<lim;i+=len)
for(int j=0,mul=1;j<now;j++,mul=1ll*mul*w%mod){
int a=x[i+j],b=1ll*x[now+i+j]*mul%mod;
x[i+j]=(a+b)%mod,x[now+i+j]=(a-b+mod)%mod;
}
}
if(opt==0)
for(int i=0;i<lim;i++)
x[i]=1ll*x[i]*inv[lim]%mod;
}
int main(){
fac[0]=fac[1]=nfac[0]=nfac[1]=inv[1]=1;
for(int i=2;i<maxn;i++)
fac[i]=1ll*fac[i-1]*i%mod,inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod,nfac[i]=1ll*nfac[i-1]*inv[i]%mod;
scanf("%d%d",&n,&m);
for(int i=0;i<=m;i++)
f.push_back(1ll*read()*nfac[i]%mod),g.push_back((i&1)? (mod-nfac[i]):nfac[i]);
lim=getlen(2*m+2),NTT(f,1),NTT(g,1);
for(int i=0;i<lim;i++)
f[i]=1ll*f[i]*g[i]%mod;
NTT(f,0),f.resize(m+1),g.clear();
for(int i=0;i<=m;i++)
f[i]=1ll*f[i]*fac[i]%mod,g.push_back(1ll*C(m+i,m)*C(m,m+m-(m+i))%mod);
NTT(f,1),NTT(g,1);
for(int i=0;i<lim;i++)
f[i]=1ll*f[i]*g[i]%mod;
NTT(f,0);
for(int i=0,C=1;i<=3*m;i++,C=1ll*C*(n-i+1)%mod*inv[i]%mod)
if(i>=m)
ans=(ans+1ll*C*f[i-m])%mod;
printf("%d\n",ans);
return 0;
}