CF960G Bandit Blues
前置芝士
第一类斯特林数递推式
\[f_{i,j}=f_{i-1,j-1}+(i-1)\times f_{i-1,j}
\]
第一类斯特林数生成函数
有符号:\(\prod_{i=0}^{n-1}(x-i)\)
无符号:\(\prod_{i=0}^{n-1}(x+i)\)
第i项的系数就是\(S_n^i\),然后这个东西可以分治+FFT在\(O(n\log^2n)\)的时间内求出
题意
求满足a个数前面,b个数后面没有比它们大的数的排列个数
思路
朴素思路
考虑一个递推
设\(f_{i,j}\)表示i个数的排列,其中j个数的前面没有比它大的数
则选择把最小的数放在哪里
最小的数放在首位的时候,前面没有比它更大的数,方案数显然是\(f_{i-1,j-1}\)
最小的数放在其他位置的时候,任意位置前面都有更大的数,方案数是\((i-1)\times f_{i-1,j}\)
所以f的递推式为
\[f_{i,j}=f_{i-1,j-1}+(i-1)\times f_{i-1,j}
\]
发现恰好是第一类斯特林数的递推公式
因为对称性,i个数的排列j个数后面没有比它更大的数的方案数也是\(f_{i,j}\)
因为题目的限制,最大的数n前面一定有a-1个满足条件的数,后面一定有b-1个满足条件的数
所以可以枚举n的位置,同时利用组合数计算哪a-1个数被放在前面得到答案
有:
\[ans=\sum_{i=1}^n S_{i-1}^{a-1} S_{n-i}^{b-1} C_{n-1}^{a-1}
\]
优化做法
朴素的做法显然是过不了的
这样考虑,有k个前面没有比它大的数,它们的位置是\(\left\{P_1,P_2,\dots,P_k \right\}\)
这些位置把整个序列分成a+b-2块,每块都是形如\(\left[ P_i,P_{i+1} \right]\)的区间,显然我们可以把b-1块翻转后放到后面使他们满足条件
所以式子变成了
\[ans=S_{n-1}^{a+b-2}C_{a+b-2}^{a-1}
\]
分治+FFT快速处理斯特林数即可
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
using namespace std;
const int MOD = 998244353,G=3,invG=332748118;
int midt[18][200100],n,a,b,jc[100100],inv[100100];
int pow(int a,int b){
int ans=1;
while(b){
if(b&1)
ans=(1LL*ans*a)%MOD;
a=(1LL*a*a)%MOD;
b>>=1;
}
return ans;
}
void init(void){
jc[0]=1,inv[0]=1;
for(int i=1;i<=max(a+b-2,a-1);i++)
jc[i]=jc[i-1]*i%MOD,inv[i]=pow(jc[i],MOD-2);
}
int C(int n,int m){
return jc[n]*inv[m]%MOD*inv[n-m]%MOD;
}
void FFT(int *a,int opt,int n){
int lim=0;
while((1<<lim)<n)
lim++;
for(int i=0;i<n;i++){
int t=0;
for(int j=0;j<lim;j++)
if((i>>j)&1)
t|=(1<<(lim-j-1));
if(i<t)
swap(a[t],a[i]);
}
for(int i=2;i<=n;i<<=1){
int len=(i/2);
int tmp=pow((opt)?G:invG,(MOD-1)/i);
for(int j=0;j<n;j+=i){
int arr=1;
for(int k=j;k<j+len;k++){
int t=arr*a[k+len];
a[k+len]=((a[k]-t)%MOD+MOD)%MOD;
a[k]=(a[k]+t)%MOD;
arr=(arr*tmp)%MOD;
}
}
}
if(opt==0){
int invn=pow(n,MOD-2);
for(int i=0;i<n;i++)
a[i]=a[i]*invn%MOD;
}
}
void solve(int l,int r,int d){
if(r==l){
midt[d][0]=l;
midt[d][1]=1;
return;
}
int mid=(l+r)>>1,nt=1;
while((nt)<=(r-l+1))
nt<<=1;
solve(l,mid,d+1);
for(int i=0;i<=mid-l+1;i++)
midt[d][i]=midt[d+1][i];
solve(mid+1,r,d+1);
for(int i=mid-l+2;i<nt;i++)
midt[d][i]=0;
for(int i=r-mid+1;i<nt;i++)
midt[d+1][i]=0;
FFT(midt[d],1,nt);
FFT(midt[d+1],1,nt);
for(int i=0;i<nt;i++)
midt[d][i]=midt[d+1][i]*midt[d][i]%MOD;
FFT(midt[d],0,nt);
}
signed main(){
scanf("%lld %lld %lld",&n,&a,&b);
if(!a||!b||(a+b-2>n-1)){
printf("%d\n",0);
return 0;
}
if(n==1){
printf("1\n");
return 0;
}
init();
solve(0,n-2,0);
printf("%lld\n",midt[0][(a+b-2)]*C(a+b-2,a-1)%MOD);
return 0;
}