【20省选十联测day2】选拔赛
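设 \(f_n\) 表示剩下 \(n\) 个人时,期望还需多少轮结束,边界为 \(f_1=0\)。每轮中每个人独立地以概率 \(p\) 留下、以 \(1-p\) 被淘汰。若所有人都留下或都被淘汰,则人数不变,对应递推式中的 \(p^n+(1-p)^n\) 一项。枚举被淘汰的人数 \(i\),有如下递推式: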
\[
\begin{aligned}
f_n&=1+\left(p^n+(1-p)^n\right)f_n+\sum_{i=1}^{n-1} f_{n-i}\binom{n}{i}(1-p)^i p^{n-i}\\
f_n&=\frac{1+\sum_{i=1}^{n-1} f_{n-i}\binom{n}{i}(1-p)^i p^{n-i}}{1-p^n-(1-p)^n}\\
f_n&=\frac{1+n!\sum_{i=1}^{n-1}\dfrac{f_{n-i}\,p^{n-i}}{(n-i)!}\cdot\dfrac{(1-p)^i}{i!}}{1-p^n-(1-p)^n}
\end{aligned}
\]
显然可以 \(O(n^2)\) 求解。
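观察到求和部分是 \(A_j=\frac{f_j p^j}{j!}\) 与 \(B_k=\frac{(1-p)^k}{k!}\) 的标准卷积。由于 \(f_n\) 依赖所有更小的 \(f_j\),\(f\) 需要在线求出,所以使用分治 FFT(代码中用模 \(998244353\) 的 NTT 实现)求解,时间复杂度 \(O(n\log^2 n)\)。也可以借助生成函数转化为多项式求逆问题。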
AC code
#include<bits/stdc++.h>
//#define LOCAL
// C stdio shorthands used by main().
#define sf scanf
#define pf printf
// Inclusive counting loops: rep walks x = y, y+1, ..., z; per walks x = y, y-1, ..., z.
#define rep(x,y,z) for(int x=(y);x<=(z);++x)
#define per(x,y,z) for(int x=(y);x>=(z);--x)
using namespace std;
using ll = long long;
// N: capacity of every global table.
// mod = 998244353 (= 119 * 2^23 + 1) is the NTT modulus; G = 3 is the root used
// by the forward NTT and invG = 332748118 is its inverse modulo mod.
const int N = 200007;
const int mod = 998244353;
const int G = 3;
const int invG = 332748118;
// Input values: n, and the fraction x / y whose residue modulo mod is stored in p.
int n;
ll p, x, y;
// Modular exponentiation by repeated squaring: returns a^b mod `mod`.
// With the default exponent mod-2 this is the modular inverse of a
// (Fermat's little theorem; `mod` is prime), so ksm(0) yields 0.
//   a: base, any value; it is reduced into [0, mod) before use.
//   b: non-negative exponent.
ll ksm(ll a,ll b=mod-2) {
    // Reduce first: a raw (unreduced) base makes a*a overflow long long
    // once a exceeds about 3.04e9, and a negative base would leave a
    // negative residue.
    a%=mod;
    if(a<0) a+=mod;
    ll s=1;
    while(b) {
        if(b&1) s=s*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return s;
}
// f[k]: expected number of rounds with k people left; solve() accumulates the
//       convolution sums here and finalises each entry at its leaf.
ll f[N];
// g[k] = (1-p)^k / k!: the fixed convolution kernel, filled in main() for k <= n.
ll g[N];
// a, b: NTT scratch buffers used by solve().
ll a[N];
ll b[N];
// re: bit-reversal permutation written by getre() and read by NTT().
int re[N];
// Modular addition for operands already reduced into [0, mod).
ll add(ll a,ll b) {
    const ll sum = a + b;
    if (sum < mod) return sum;
    return sum - mod;
}
void getre(int len) {
rep(i,1,(1<<len)-1) {
re[i]=(re[i>>1]>>1)|((i&1)<<(len-1));
}
}
// In-place iterative number-theoretic transform of a[0 .. 2^len - 1] modulo `mod`.
// type == 1 uses root G; any other value uses invG (the inverse transform).
// The inverse transform is NOT scaled by 1/2^len here; the caller must do that.
// Requires re[] to hold the permutation from getre(len) for the same len.
void NTT(ll *a,int len,int type) {
    const int size = 1 << len;
    for (int idx = 0; idx < size; ++idx) {
        if (idx < re[idx]) swap(a[idx], a[re[idx]]);
    }
    const ll root = (type == 1) ? G : invG;
    for (int half = 1; half < size; half <<= 1) {
        const int span = half << 1;
        const ll step = ksm(root, (mod - 1) / span);
        for (int base = 0; base < size; base += span) {
            ll w = 1;
            for (int off = 0; off < half; ++off) {
                const ll lo = a[base + off];
                const ll hi = w * a[base + half + off] % mod;
                a[base + off] = add(lo, hi);
                a[base + half + off] = add(lo, mod - hi);
                w = w * step % mod;
            }
        }
    }
}
// Precomputed tables, filled in main() for indices 0 .. 2n:
ll pnn[N]; // pnn[i] = (1-p)^i
ll pn[N];  // pn[i]  = p^i
ll jc[N];  // jc[i]  = i!
ll in[N];  // in[i]  = 1 / i!  (modular inverse)
// Divide-and-conquer ("CDQ") online convolution over the index range [l, r),
// where r - l == 2^len.
//
// Invariant: on entry, f[k] for k in [l, r) already holds every term
// k! * A[j] * B[k-j] with j < l, where A[j] = f[j] * p^j / j! and
// B[m] = g[m] = (1-p)^m / m!. The call adds the terms with j in [l, mid)
// to the targets k in [mid, r) and recurses on both halves, so each f[k]
// is complete when its leaf (len == 0) is reached. The leaf then applies
// f[k] = (sum + 1) / (1 - p^k - (1-p)^k); f[1] stays 0.
// Uses the global scratch buffers a, b and re, so the statement order
// (left recursion, then convolution, then right recursion) matters.
void solve(int l,int r,int len) {
if(len==0) {
// Leaf: finalise f[l]. This runs before the l>n check below, so it can also
// finalise positions just past n; those values never feed into f[n].
if(l>1) f[l]=add(f[l],1ll)*ksm(add(1,mod-add(pn[l],pnn[l])))%mod;
return;
}
// No index in this range is needed for f[n].
if(l>n) return;
int mid=(l+r)>>1;
int s=r-l;
// Complete f on the left half first.
solve(l,mid,len-1);
getre(len);
// a[i] = f[l+i] * p^(l+i) / (l+i)! for the left half, zero-padded to length s.
rep(i,0,s/2-1) a[i]=f[i+l]*pn[i+l]%mod*in[i+l]%mod;
memset(a+s/2,0,sizeof(ll)*s/2);
// b = kernel g[0 .. s-1].
memcpy(b,g,sizeof(ll)*s);
// Cyclic convolution of length s: wrap-around only lands in indices < s/2,
// which are not read below.
NTT(a,len,1),NTT(b,len,1);
rep(i,0,s-1) a[i]=a[i]*b[i]%mod;
NTT(a,len,-1);
// NTT() does not scale the inverse transform, so divide by s here.
ll inv=ksm(s);
rep(i,0,s-1) a[i]=a[i]*inv%mod;
// Add k! * (convolution value) into the right-half targets k = l + i.
rep(i,s/2,s-1) f[i+l]=add(f[i+l],a[i]*jc[i+l]%mod);
solve(mid,r,len-1);
}
int main(){
#ifdef LOCAL
freopen("in.txt","r",stdin);
freopen("my.out","w",stdout);
#endif
sf("%d%lld%lld",&n,&x,&y);
p=x*ksm(y)%mod;
pn[0]=1;jc[0]=1;pnn[0]=1;
rep(i,1,n<<1) pn[i]=pn[i-1]*p%mod,pnn[i]=pnn[i-1]*add(1,mod-p)%mod,jc[i]=jc[i-1]*i%mod;
in[n<<1]=ksm(jc[n<<1]);
per(i,(n<<1)-1,0) in[i]=in[i+1]*(i+1)%mod;
rep(i,0,n) g[i]=in[i]*pnn[i]%mod;
int len=0;
while((1<<len)<n) len++;
solve(1,(1<<len)+1,len);
pf("%lld\n",f[n]);
}
本文来自博客园,作者:liyixin,转载请注明原文链接:https://www.cnblogs.com/liyixin0514/p/18425895