Examples

loj#2504. 「2018 集训队互测 Day 5」小 H 爱染色 解题报告

题意

给定 \(n\) 个球,初始为白色,进行两次操作,每次选择 \(m\) 个球染黑,给定一个 \(m\) 次多项式 \(F\)(仅给出前 \(m+1\) 项点值),设编号最小的黑球为 \(A\)\(1\leqslant A\leqslant n\)),求 \(F(A-1)\) 的期望,对 \(998244353\) 取模。

\(1\leqslant n<9988244353,1\leqslant m\leqslant 10^6\)

分析

根据期望的线性性,我们把多项式拆成若干个单项式的和,那么我们设目前的次数为 \(c\),我们对于 \((A-1)^c\) 构造出一个组合意义,找到一个最长的白球前缀,在其中任选 \(c\) 个球的方案数,列出式子:(\(f_i\) 表示多项式第 \(i\) 项的系数,\(i\) 枚举在后面那一步实际上选到的白球数量,\(j\) 枚举实际上被染黑的球的数量)

\[ans_c=f_c\sum_{i=0}^m\begin{Bmatrix}c\\i\end{Bmatrix}i!\sum_{j=m}^{2m}{j\choose m}{m\choose m-(j-m)}{n\choose i+j} \]

推一推式子:

\[\sum_{c=0}^mf_c\sum_{i=0}^m\begin{Bmatrix}c\\i\end{Bmatrix}i!\sum_{j=m}^{2m}{j\choose m}{m\choose m-(j-m)}{n\choose i+j}\\=\sum_{i=0}^mi!\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}{n\choose i+j}\sum_{c=0}^mf_c\begin{Bmatrix}c\\i\end{Bmatrix} \]

我们用通项公式展开斯特林数:

\[\sum_{i=0}^mi!\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}{n\choose i+j}\sum_{c=0}^mf_c\frac{1}{i!}\sum_{k=0}^i{i\choose k}(-1)^{i-k}k^c\\=\sum_{i=0}^m\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}{n\choose i+j}\sum_{k=0}^i{i\choose k}(-1)^{i-k}\sum_{c=0}^m f_ck^c\\=\sum_{T=m}^{3m}{n\choose T}\sum_{j=m}^{2m}{j\choose m}{m\choose 2m-j}\sum_{k=0}^{T-j}{T-j\choose k}(-1)^{T-j-k}F(k) \]

其中 \(F(k)\) 表示 \(x=k\) 的时候的点值,容易发现后面枚举 \(k\) 的部分可以使用一次卷积算出来,前面枚举 \(j\) 的部分也可以使用一次卷积算出来,于是复杂度 \(O(n\log n)\)

代码

#include<stdio.h>
#include<vector>
using namespace std;
const int maxn=1<<22,mod=998244353,G=3,invG=(mod+1)/G;
typedef vector<int>poly;
int n,m,ans,lim;
int p[maxn],inv[maxn],fac[maxn],nfac[maxn],F[maxn];
poly f,g;
inline int read(){
	int x=0;
	char c=getchar();
	for(;c<'0'||c>'9';c=getchar());
	for(;c>='0'&&c<='9';c=getchar())
		x=x*10+c-48;
	return x;
}
inline int C(int a,int b){
	return a<b? 0:1ll*fac[a]*nfac[b]%mod*nfac[a-b]%mod;
}
int ksm(int a,int b,int mod){
	int res=1;
	while(b){
		if(b&1)
			res=1ll*res*a%mod;
		a=1ll*a*a%mod,b>>=1;
	}
	return res;
}
int getlen(int n){
	int lim=1,r=0;
	for(;lim<n;lim<<=1,r++);
	for(int i=0;i<lim;i++)
		p[i]=(p[i>>1]>>1)|((i&1)<<(r-1));
	return lim;
}
void NTT(poly &x,int opt){
	x.resize(lim);
	for(int i=0;i<lim;i++)
		if(i<p[i])
			swap(x[i],x[p[i]]);
	for(int len=2,now=1,p=1;len<=lim;len<<=1,now<<=1,p++){
		int w=ksm(opt==1? G:invG,(mod-1)/len,mod);
		for(int i=0;i<lim;i+=len)
			for(int j=0,mul=1;j<now;j++,mul=1ll*mul*w%mod){
				int a=x[i+j],b=1ll*x[now+i+j]*mul%mod;
				x[i+j]=(a+b)%mod,x[now+i+j]=(a-b+mod)%mod;
			}
	}
	if(opt==0)
		for(int i=0;i<lim;i++)
			x[i]=1ll*x[i]*inv[lim]%mod;
}
int main(){
	fac[0]=fac[1]=nfac[0]=nfac[1]=inv[1]=1;
	for(int i=2;i<maxn;i++)
		fac[i]=1ll*fac[i-1]*i%mod,inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod,nfac[i]=1ll*nfac[i-1]*inv[i]%mod;
	scanf("%d%d",&n,&m);
	for(int i=0;i<=m;i++)
		f.push_back(1ll*read()*nfac[i]%mod),g.push_back((i&1)? (mod-nfac[i]):nfac[i]);
	lim=getlen(2*m+2),NTT(f,1),NTT(g,1);
	for(int i=0;i<lim;i++)
		f[i]=1ll*f[i]*g[i]%mod;
	NTT(f,0),f.resize(m+1),g.clear();
	for(int i=0;i<=m;i++)
		f[i]=1ll*f[i]*fac[i]%mod,g.push_back(1ll*C(m+i,m)*C(m,m+m-(m+i))%mod);
	NTT(f,1),NTT(g,1);
	for(int i=0;i<lim;i++)
		f[i]=1ll*f[i]*g[i]%mod;
	NTT(f,0);
	for(int i=0,C=1;i<=3*m;i++,C=1ll*C*(n-i+1)%mod*inv[i]%mod)
		if(i>=m)
			ans=(ans+1ll*C*f[i-m])%mod;
	printf("%d\n",ans);
	return 0;
}
posted @ 2021-10-29 20:50  xiaoziyao  阅读(65)  评论(0编辑  收藏  举报