Codeforces960G Bandit Blues 【斯特林数】【FFT】

题目大意:

  求满足比之前的任何数小的有A个,比之后的任何数小的有B个的长度为n的排列个数。

题目分析:

  首先写出递推式,设s(n,k)表示长度为n的排列,比之前的数小的数有k个。

  我们假设新加入的数为1,那么s(n,k)=s(n-1,k-1)+(n-1)*s(n,k)。

  这个式子是第一类斯特林数的递推式。

  用h(n,a,b)表示满足题目给出条件的排列个数。

  得出h(n,a,b)=Σs(k,a-1)*s(n-k-1,b-1)*C(n-1,k)。直观的理解就是将原排列从最高点分成两部分,两部分分别组合然后乘起来。

  这样我们发现h(n,a,b)=s(n-1,a+b-2)*C(a+b-2,a-1)。这实际上就是给出一个a+b-2的排列,然后选出其中需要的点放到右边,我们不用考虑多余的点,因为它们的排列已经被计算。

  由于无符号第一类斯特林数对应着升幂的系数,构造x(x+1)(x+2)...(x+n-1),它的x^k的系数等于s(n,k)的值,由于最高项系数为1,所以分治FFT。

代码:

  

 1 #include<bits/stdc++.h>
 2 #pragma GCC optimize(2)
 3 using namespace std;
 4 
 5 const int mod = 998244353;
 6 const int gg = 3;
 7 
 8 int n,a,b;
 9 
10 vector<int> res[205000];
11 
12 int up[405000];
13 
14 int ord[405000];
15 
16 int fast_pow(int now,int pw){
17     if(pw == 0) return 1;
18     if(pw == 1) return now;
19     int z = fast_pow(now,pw/2);
20     z = (1ll*z*z)%mod;
21     if(pw & 1){z= (1ll*z*now)%mod;}
22     return z;
23 }
24 
25 void fft(int now,int len,int f){
26     for(int i=0;i<len;i++) if(i<ord[i]) swap(res[now][i],res[now][ord[i]]);
27     for(int i=1;i<len;i<<=1){
28     int wn = fast_pow(gg,(mod-1)/(i<<1));
29     if(f == -1) wn = fast_pow(wn,mod-2);
30     for(int j=0;j<len;j+=(i<<1)){
31         for(int k=0,w=1;k<i;k++,w = (1ll*w*wn)%mod){
32         int x = res[now][j+k],y = (1ll*w*res[now][j+k+i])%mod;
33         res[now][j+k] = (x+y)%mod;
34         res[now][j+k+i] = (x-y+mod)%mod;
35         }
36     }
37     }
38     if(f == -1){
39     int iv = fast_pow(len,mod-2);
40     for(int i=0;i<len;i++) res[now][i] = (1ll*res[now][i]*iv)%mod;
41     }
42 }
43 
44 void multi(int p1,int p2){
45     int n1 = res[p1].size()-1,n2 = res[p2].size()-1;
46     int len = 1,om = 0;
47     while(len <= (n1+n2+1))len<<=1,om++;
48     for(int i=n1+1;i<len;i++) res[p1].push_back(0);
49     for(int i=n2+1;i<len;i++) res[p2].push_back(0);
50     for(int i=0;i<len;i++) ord[i] = (ord[i>>1]>>1)+((i&1)<<om-1);
51     fft(p1,len,1);fft(p2,len,1);
52     for(int i=0;i<len;i++){
53     res[p1][i] = (1ll*res[p1][i]*res[p2][i])%mod;
54     if(res[p1][i] < 0) res[p1][i]+=mod;
55     }
56     fft(p1,len,-1);
57     res[p2].clear();
58 }
59 
60 void divide(int l,int r,int now){
61     if(l == r) {up[now] = l;return;}
62     int mid = (l+r)/2;
63     divide(l,mid,now<<1);
64     divide(mid+1,r,now<<1|1);
65     multi(up[now<<1],up[now<<1|1]);
66     up[now] = up[now<<1];
67 }
68 
69 void work(){
70     if(a == 0 || b == 0){puts("0");return;}
71     if(n == 1){if(a+b==2)puts("1"); else puts("0"); return;}
72     int c = 1;
73     if(a<b) swap(a,b);
74     if(a-1 > a+b-2) c = 0;
75     for(int i=1;i<=a-1;i++){
76     c = (1ll*c*(a+b-1-i))%mod;
77     c = (1ll*c*fast_pow(i,mod-2))%mod;
78     }
79     for(int i=1;i<n;i++) res[i].push_back(i-1),res[i].push_back(1);
80     divide(1,n-1,1);
81     c = (1ll*c*res[up[1]][a+b-2])%mod;
82     printf("%d",c);
83 }
84 
85 int main(){
86     scanf("%d%d%d",&n,&a,&b);
87     work();
88     return 0;
89 }

 

posted @ 2018-04-28 08:24  menhera  阅读(663)  评论(1编辑  收藏  举报