Codeforces960G Bandit Blues 【斯特林数】【FFT】
题目大意:
求满足比之前的任何数小的有A个,比之后的任何数小的有B个的长度为n的排列个数。
题目分析:
首先写出递推式,设s(n,k)表示长度为n的排列,比之前的数小的数有k个。
我们假设新加入的数为1,那么s(n,k)=s(n-1,k-1)+(n-1)*s(n,k)。
这个式子是第一类斯特林数的递推式。
用h(n,a,b)表示满足题目给出条件的排列个数。
得出h(n,a,b)=Σs(k,a-1)*s(n-k-1,b-1)*C(n-1,k)。直观的理解就是将原排列从最高点分成两部分,两部分分别组合然后乘起来。
这样我们发现h(n,a,b)=s(n-1,a+b-2)*C(a+b-2,a-1)。这实际上就是给出一个a+b-2的排列,然后选出其中需要的点放到右边,我们不用考虑多余的点,因为它们的排列已经被计算。
由于无符号第一类斯特林数对应着升幂的系数,构造x(x+1)(x+2)...(x+n-1),它的x^k的系数等于s(n,k)的值,由于最高项系数为1,所以分治FFT。
代码:
1 #include<bits/stdc++.h> 2 #pragma GCC optimize(2) 3 using namespace std; 4 5 const int mod = 998244353; 6 const int gg = 3; 7 8 int n,a,b; 9 10 vector<int> res[205000]; 11 12 int up[405000]; 13 14 int ord[405000]; 15 16 int fast_pow(int now,int pw){ 17 if(pw == 0) return 1; 18 if(pw == 1) return now; 19 int z = fast_pow(now,pw/2); 20 z = (1ll*z*z)%mod; 21 if(pw & 1){z= (1ll*z*now)%mod;} 22 return z; 23 } 24 25 void fft(int now,int len,int f){ 26 for(int i=0;i<len;i++) if(i<ord[i]) swap(res[now][i],res[now][ord[i]]); 27 for(int i=1;i<len;i<<=1){ 28 int wn = fast_pow(gg,(mod-1)/(i<<1)); 29 if(f == -1) wn = fast_pow(wn,mod-2); 30 for(int j=0;j<len;j+=(i<<1)){ 31 for(int k=0,w=1;k<i;k++,w = (1ll*w*wn)%mod){ 32 int x = res[now][j+k],y = (1ll*w*res[now][j+k+i])%mod; 33 res[now][j+k] = (x+y)%mod; 34 res[now][j+k+i] = (x-y+mod)%mod; 35 } 36 } 37 } 38 if(f == -1){ 39 int iv = fast_pow(len,mod-2); 40 for(int i=0;i<len;i++) res[now][i] = (1ll*res[now][i]*iv)%mod; 41 } 42 } 43 44 void multi(int p1,int p2){ 45 int n1 = res[p1].size()-1,n2 = res[p2].size()-1; 46 int len = 1,om = 0; 47 while(len <= (n1+n2+1))len<<=1,om++; 48 for(int i=n1+1;i<len;i++) res[p1].push_back(0); 49 for(int i=n2+1;i<len;i++) res[p2].push_back(0); 50 for(int i=0;i<len;i++) ord[i] = (ord[i>>1]>>1)+((i&1)<<om-1); 51 fft(p1,len,1);fft(p2,len,1); 52 for(int i=0;i<len;i++){ 53 res[p1][i] = (1ll*res[p1][i]*res[p2][i])%mod; 54 if(res[p1][i] < 0) res[p1][i]+=mod; 55 } 56 fft(p1,len,-1); 57 res[p2].clear(); 58 } 59 60 void divide(int l,int r,int now){ 61 if(l == r) {up[now] = l;return;} 62 int mid = (l+r)/2; 63 divide(l,mid,now<<1); 64 divide(mid+1,r,now<<1|1); 65 multi(up[now<<1],up[now<<1|1]); 66 up[now] = up[now<<1]; 67 } 68 69 void work(){ 70 if(a == 0 || b == 0){puts("0");return;} 71 if(n == 1){if(a+b==2)puts("1"); else puts("0"); return;} 72 int c = 1; 73 if(a<b) swap(a,b); 74 if(a-1 > a+b-2) c = 0; 75 for(int i=1;i<=a-1;i++){ 76 c = (1ll*c*(a+b-1-i))%mod; 77 c = (1ll*c*fast_pow(i,mod-2))%mod; 78 } 79 for(int i=1;i<n;i++) res[i].push_back(i-1),res[i].push_back(1); 80 divide(1,n-1,1); 81 c = (1ll*c*res[up[1]][a+b-2])%mod; 82 printf("%d",c); 83 } 84 85 int main(){ 86 scanf("%d%d%d",&n,&a,&b); 87 work(); 88 return 0; 89 }