【BZOJ1272】Gate Of Babylon [Lucas][组合数][逆元]
Gate Of Babylon
Time Limit: 10 Sec Memory Limit: 162 MB[Submit][Status][Discuss]
Description
Input
Output
Sample Input
2 1 10 13
3
Sample Output
12
HINT
Main idea
有若干个没有限制的道具,以及T个有限制个数的道具,取出m个,求方案数。
Solution
首先,看到有限制的只有15个,因此可以考虑使用容斥原理:Ans=全部没有限制的方案-有1个超过限制的方案数+有2个超过限制的方案数-有3个超过限制的方案数…。
以此类推。我们先考虑没有限制的,在m组无限制的数中选n个的方案数,显然就是C(n+m-1,n)。
因为这道题是要求不超过m的方案数,也就是那么运用加法原理,发现答案也就是C(n+0-1,0)+C(n+1-1,1)+C(n+2-1,2)+...+C(n+m-1,m)=C(n+m,m)。
然后考虑有限制的情况,有一个超过限制直接用总数减去(这个的限制+1)就是当前的总数,相当于强制要选限制+1个为空。
然后只要DFS,记录到当前为止选了几个,答案要记是b[i]+1,判断加减,最后累加答案。
最后,n、m过大,发现p是一个质数,所以可以用Lucas定理:Lucas(n,m,p)=Lucas(n/p,m/p,p)*C(n%p,m%p),其中C(n%p,m%p)求的时候要用到乘法逆元。
Code
1 #include<iostream>
2 #include<string>
3 #include<algorithm>
4 #include<cstdio>
5 #include<cstring>
6 #include<cstdlib>
7 #include<cmath>
8 using namespace std;
9
10 const int ONE=1000001;
11
12 int n,T,m,MOD;
13 long long Ans;
14 long long Jc[ONE];
15 int b[ONE];
16
17 int get()
18 {
19 int res,Q=1; char c;
20 while( (c=getchar())<48 || c>57)
21 if(c=='-')Q=-1;
22 if(Q) res=c-48;
23 while((c=getchar())>=48 && c<=57)
24 res=res*10+c-48;
25 return res*Q;
26 }
27
28 long long Quickpow(int a,int b,int MOD)
29 {
30 long long res=1;
31 while(b)
32 {
33 if(b&1) res=res*a%MOD;
34 a=(long long)a*a%MOD;
35 b/=2;
36 }
37 return res;
38 }
39
40 int C(int m,int n)
41 {
42 if(m<n) return 0;
43 int up=Jc[m]%MOD;
44 int down=(long long)Jc[m-n]*Jc[n]%MOD;
45 return (long long)up*Quickpow(down,MOD-2,MOD)%MOD;
46 }
47
48 int Lucas(int n,int m,int MOD)
49 {
50 long long res=1;
51 if(n<m) return 0;
52 while(n && m)
53 {
54 res=res*C(n%MOD,m%MOD)%MOD;
55 n/=MOD; m/=MOD;
56 }
57 return res;
58 }
59
60 void Dfs(int len,int PD,int val)
61 {
62 if(len==T+1)
63 {
64 Ans+=PD*Lucas(n+m-val,m-val,MOD);
65 Ans+=MOD;
66 Ans%=MOD;
67 return;
68 }
69 Dfs(len+1,PD,val);
70 Dfs(len+1,-PD,val+b[len]+1);
71 }
72
73 int main()
74 {
75 n=get(); T=get(); m=get(); MOD=get();
76 Jc[0]=1; for(int i=1;i<=MOD;i++) Jc[i]=(long long)Jc[i-1]*i%MOD;
77 for(int i=1;i<=T;i++)
78 b[i]=get();
79 Dfs(1,1,0);
80 printf("%d",Ans);
81 }