CF960G Bandit Blues

前置芝士

第一类斯特林数递推式

\[f_{i,j}=f_{i-1,j-1}+(i-1)\times f_{i-1,j} \]

第一类斯特林数生成函数

有符号:\(\prod_{i=0}^{n-1}(x-i)\)
无符号:\(\prod_{i=0}^{n-1}(x+i)\)
第i项的系数就是\(S_n^i\),然后这个东西可以分治+FFT在\(O(n\log^2n)\)的时间内求出

题意

求满足a个数前面,b个数后面没有比它们大的数的排列个数

思路

朴素思路

考虑一个递推
\(f_{i,j}\)表示i个数的排列,其中j个数的前面没有比它大的数
则选择把最小的数放在哪里
最小的数放在首位的时候,前面没有比它更大的数,方案数显然是\(f_{i-1,j-1}\)
最小的数放在其他位置的时候,任意位置前面都有更大的数,方案数是\((i-1)\times f_{i-1,j}\)
所以f的递推式为

\[f_{i,j}=f_{i-1,j-1}+(i-1)\times f_{i-1,j} \]

发现恰好是第一类斯特林数的递推公式
因为对称性,i个数的排列j个数后面没有比它更大的数的方案数也是\(f_{i,j}\)
因为题目的限制,最大的数n前面一定有a-1个满足条件的数,后面一定有b-1个满足条件的数
所以可以枚举n的位置,同时利用组合数计算哪a-1个数被放在前面得到答案
有:

\[ans=\sum_{i=1}^n S_{i-1}^{a-1} S_{n-i}^{b-1} C_{n-1}^{a-1} \]

优化做法

朴素的做法显然是过不了的
这样考虑,有k个前面没有比它大的数,它们的位置是\(\left\{P_1,P_2,\dots,P_k \right\}\)
这些位置把整个序列分成a+b-2块,每块都是形如\(\left[ P_i,P_{i+1} \right]\)的区间,显然我们可以把b-1块翻转后放到后面使他们满足条件
所以式子变成了

\[ans=S_{n-1}^{a+b-2}C_{a+b-2}^{a-1} \]

分治+FFT快速处理斯特林数即可

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
using namespace std;
const int MOD = 998244353,G=3,invG=332748118;
int midt[18][200100],n,a,b,jc[100100],inv[100100];
int pow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)
            ans=(1LL*ans*a)%MOD;
        a=(1LL*a*a)%MOD;
        b>>=1;
    }
    return ans;
}
void init(void){
    jc[0]=1,inv[0]=1;
    for(int i=1;i<=max(a+b-2,a-1);i++)
        jc[i]=jc[i-1]*i%MOD,inv[i]=pow(jc[i],MOD-2);
}
int C(int n,int m){
    return jc[n]*inv[m]%MOD*inv[n-m]%MOD;
}
void FFT(int *a,int opt,int n){
    int lim=0;
    while((1<<lim)<n)
        lim++;
    for(int i=0;i<n;i++){
        int t=0;
        for(int j=0;j<lim;j++)
            if((i>>j)&1)
                t|=(1<<(lim-j-1));
        if(i<t)
            swap(a[t],a[i]);
    }
    for(int i=2;i<=n;i<<=1){
        int len=(i/2);
        int tmp=pow((opt)?G:invG,(MOD-1)/i);
        for(int j=0;j<n;j+=i){
            int arr=1;
            for(int k=j;k<j+len;k++){
                int t=arr*a[k+len];
                a[k+len]=((a[k]-t)%MOD+MOD)%MOD;
                a[k]=(a[k]+t)%MOD;
                arr=(arr*tmp)%MOD;
            }
        }
    }
    if(opt==0){
        int invn=pow(n,MOD-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*invn%MOD;
    }
}
void solve(int l,int r,int d){
    if(r==l){
        midt[d][0]=l;
        midt[d][1]=1;
        return;
    }
    int mid=(l+r)>>1,nt=1;
    while((nt)<=(r-l+1))
        nt<<=1;
    solve(l,mid,d+1);
    for(int i=0;i<=mid-l+1;i++)
        midt[d][i]=midt[d+1][i];
    solve(mid+1,r,d+1);
    for(int i=mid-l+2;i<nt;i++)
        midt[d][i]=0;
    for(int i=r-mid+1;i<nt;i++)
        midt[d+1][i]=0;
    FFT(midt[d],1,nt);
    FFT(midt[d+1],1,nt);
    for(int i=0;i<nt;i++)
        midt[d][i]=midt[d+1][i]*midt[d][i]%MOD;
    FFT(midt[d],0,nt);
}
signed main(){
    scanf("%lld %lld %lld",&n,&a,&b);
    if(!a||!b||(a+b-2>n-1)){
        printf("%d\n",0);
        return 0;
    }
    if(n==1){
        printf("1\n");
        return 0;
    }
    init();
    solve(0,n-2,0);
    printf("%lld\n",midt[0][(a+b-2)]*C(a+b-2,a-1)%MOD);
    return 0;
}
posted @ 2019-02-25 19:11  dreagonm  阅读(178)  评论(1编辑  收藏  举报