【20省选十联测day2】选拔赛

【20省选十联测day2】选拔赛

\(f_n\) 表示剩下 \(n\) 个人时,期望多少轮结束。有如下递推式:

\[f_n=1+(p^n+(1-p)^n)f_n+\sum_{i=1}^{n-1} f_{n-i} \binom{n}{i} (1-p)^i p^{n-i}\\ f_n=(\sum_{i=1}^{n-1} f_{n-i} \binom{n}{i} (1-p)^i p^{n-i}+1)\div(1-p^n-(1-p)^n)\\ f_n=((\sum_{i=1}^{n-1} n! f_{n-i}p^{n-i} \frac{1}{(n-i)!} \frac{1}{i!}(1-p)^i )+1)\div(1-p^n-(1-p)^n)\]

显然可以 \(n^2\) 求解。

观察到 sum 的那一坨是一个卷积,标准卷积。但是由于 \(f\) 是未知的,所以考虑分治 FFT 求解,时间复杂度 \(O(n\log^2 n)\)。也可以借助生成函数转化成多项式求逆问题。

AC code

#include<bits/stdc++.h>
//#define LOCAL
#define sf scanf
#define pf printf
#define rep(x,y,z) for(int x=y;x<=z;x++)
#define per(x,y,z) for(int x=y;x>=z;x--)
using namespace std;
typedef long long ll;
const int N=2e5+7,mod=998244353,G=3,invG=332748118;
int n;
ll p,x,y;
ll ksm(ll a,ll b=mod-2) {
	ll s=1;
	while(b) {
		if(b&1) s=s*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return s;
}
ll f[N],g[N],a[N],b[N];
int re[N];
ll add(ll a,ll b) {return a+b>=mod?a+b-mod:a+b; }
void getre(int len) {
	rep(i,1,(1<<len)-1) {
		re[i]=(re[i>>1]>>1)|((i&1)<<(len-1));
	}
}
void NTT(ll *a,int len,int type) {
	rep(i,0,(1<<len)-1) {
		if(i<re[i]) swap(a[i],a[re[i]]);
	}
	for(int k=1;k<(1<<len);k<<=1) {
		ll wn=ksm(type==1?G:invG,(mod-1)/(k<<1));
		for(int r=k<<1,j=0;j<(1<<len);j+=r) {
			ll w=1;
			for(int i=0;i<k;i++,w=w*wn%mod) {
				ll x=a[j+i],y=w*a[j+k+i]%mod;
				a[j+i]=add(x,y);
				a[j+k+i]=add(x,mod-y);
			}
		}
	}
}
ll pnn[N],pn[N],jc[N],in[N];
void solve(int l,int r,int len) {
	if(len==0) {
		if(l>1) f[l]=add(f[l],1ll)*ksm(add(1,mod-add(pn[l],pnn[l])))%mod;
		return;
	}
	if(l>n) return;
	int mid=(l+r)>>1;
	int s=r-l;
	solve(l,mid,len-1);
	getre(len);
	rep(i,0,s/2-1) a[i]=f[i+l]*pn[i+l]%mod*in[i+l]%mod;
	memset(a+s/2,0,sizeof(ll)*s/2);
	memcpy(b,g,sizeof(ll)*s);
	NTT(a,len,1),NTT(b,len,1);
	rep(i,0,s-1) a[i]=a[i]*b[i]%mod;
	NTT(a,len,-1);
	ll inv=ksm(s);
	rep(i,0,s-1) a[i]=a[i]*inv%mod;
	rep(i,s/2,s-1) f[i+l]=add(f[i+l],a[i]*jc[i+l]%mod);
	solve(mid,r,len-1);
}
int main(){
	#ifdef LOCAL
	freopen("in.txt","r",stdin);
	freopen("my.out","w",stdout);
	#endif
	sf("%d%lld%lld",&n,&x,&y);
	p=x*ksm(y)%mod;
	pn[0]=1;jc[0]=1;pnn[0]=1;
	rep(i,1,n<<1) pn[i]=pn[i-1]*p%mod,pnn[i]=pnn[i-1]*add(1,mod-p)%mod,jc[i]=jc[i-1]*i%mod;
	in[n<<1]=ksm(jc[n<<1]);
	per(i,(n<<1)-1,0) in[i]=in[i+1]*(i+1)%mod;
	rep(i,0,n) g[i]=in[i]*pnn[i]%mod;
	int len=0;
	while((1<<len)<n) len++;
	solve(1,(1<<len)+1,len);
	pf("%lld\n",f[n]);
}
posted @ 2024-09-22 21:26  liyixin  阅读(4)  评论(0编辑  收藏  举报