BZOJ3129: [Sdoi2013]方程
拓展Lucas+容斥原理
// BZOJ3129 [SDOI2013] -- extended Lucas + inclusion-exclusion.
//
// What solve() computes, as implied by its transformations:
//   * m -= n               : shifts every variable down by 1 (x_i >= 1 -> y_i >= 0);
//   * m -= (t - 1)         : for each of the n2 values t, an extra shift
//                            (together with the first one: y_i = x_i - t >= 0);
//   * inclusion-exclusion over the n1 values a[i]: subtracting a[i] from m
//                            forces y_i >= a[i], i.e. x_i > a[i];
//   * every term is the stars-and-bars count C(m + n - 1, n - 1) mod p.
// Lucas::init factors p into prime powers and Lucas::solve recombines with
// the CRT, so p does not have to be prime.
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cmath>
#include<queue>
#define MAXN 10000+10
#define INF 0x7f7f7f7f
#define LINF 0x7f7f7f7f7f7f7f7f
#define ll long long
#define pb push_back
#define ft first
#define sc second
#define mp make_pair
#define pil pair<int,ll>
#define pll pair<ll,ll>
using namespace std;

// Binomial coefficients modulo an arbitrary modulus p (extended Lucas):
// C(n, m) is computed modulo each prime power pk = pi^k dividing p, and the
// residues are merged with the Chinese Remainder Theorem.
struct Lucas{
    // Extended Euclid: sets x, y such that a*x + b*y = gcd(a, b).
    void extgcd(ll a,ll b,ll &x,ll &y){
        if(!b){x=1;y=0;return;}
        ll xx,yy;
        extgcd(b,a%b,xx,yy);
        x=yy;
        y=xx-a/b*yy;
    }
    // Inverse of a modulo b (requires gcd(a, b) == 1). Result lies in [1, b];
    // a zero residue (only possible when b == 1) is mapped to b.
    ll Inv(ll a,ll b){
        ll x,y;
        extgcd(a,b,x,y);
        x=(x%b+b)%b;
        if(!x)x+=b;
        return x;
    }
    // a^b mod p by binary exponentiation (returns 1 when b == 0).
    ll Pow(ll a,ll b,ll p){
        ll ret=1LL;
        a%=p;  // keep a*a below p^2 even if the caller passes a >= p
        while(b){
            if(b&1){ret=ret*a%p;}
            a=a*a%p;
            b>>=1;
        }
        return ret;
    }
    // n! with every factor pi removed, taken modulo pk (pk = pi^k).
    ll fac(ll n,ll pi,ll pk){
        if(!n)return 1LL;
        ll ret=1LL;
        // Product of the non-multiples of pi in one full period [1, pk);
        // n! contains n/pk such periods, all congruent modulo pk.
        for(ll i=2;i<pk;i++){
            if(i%pi)ret=ret*i%pk;
        }
        ret=Pow(ret,n/pk,pk);
        // Leftover partial period [1, n % pk].
        for(ll i=2;i<=(n%pk);i++){
            if(i%pi)ret=ret*i%pk;
        }
        // The multiples of pi contribute pi^(n/pi) * (n/pi)!; the pi powers are
        // accounted for separately in C(), so only (n/pi)! is recursed on.
        return ret*fac(n/pi,pi,pk)%pk;
    }
    // C(n, m) mod pk for 0 <= m <= n.
    ll C(ll n,ll m,ll pi,ll pk){
        ll a=fac(n,pi,pk),b=fac(m,pi,pk),c=fac(n-m,pi,pk);
        // t = v_pi(n!) - v_pi(m!) - v_pi((n-m)!) via Legendre's formula.
        ll t=0LL;
        for(ll i=n/pi;i;i/=pi)t+=i;
        for(ll i=m/pi;i;i/=pi)t-=i;
        for(ll i=(n-m)/pi;i;i/=pi)t-=i;
        // Reduce after every product so intermediates stay below ~pk^2.
        ll ret=a*Inv(b,pk)%pk*Inv(c,pk)%pk;
        ret=ret*Pow(pi,t,pk)%pk;
        return ret;
    }
    ll p;            // full modulus
    vector<pll> pn;  // (prime pi, prime power pk) for each maximal pi^k dividing p
    // Factors pp into prime powers; must be called before solve().
    void init(ll pp){
        p=pp;
        pn.clear();
        // Exact integer bound on the shrinking cofactor (no floating-point sqrt).
        for(ll i=2;i*i<=pp;i++){
            if(pp%i==0){
                ll pk=1LL;
                while(pp%i==0){
                    pp/=i;
                    pk*=i;
                }
                pn.pb(mp(i,pk));
            }
        }
        if(pp!=1){
            pn.pb(mp(pp,pp));  // remaining cofactor is a prime
        }
    }
    // C(n, m) mod p: CRT-combine the residues modulo each prime power.
    ll solve(ll n,ll m){
        ll ans=0LL;
        for(size_t i=0;i<pn.size();i++){
            ll pi=pn[i].ft,pk=pn[i].sc;
            ll Mi=p/pk;  // divisible by every other prime power, coprime to pk
            ll t=C(n,m,pi,pk);
            t=t*Mi%p;
            t=t*Inv(Mi,pk)%p;
            ans=(ans+t)%p;
        }
        return ans;
    }
}L;

int T,n,n1,n2;
// Kept as long long: m - n - sum(t - 1) can drop far below INT_MIN.
ll m;
int a[10];
ll ans,p;

// Non-negative solutions of y_1 + ... + y_n = m, i.e. C(m + n - 1, n - 1) mod p.
ll calc(ll n,ll m){
    return L.solve(m+n-1,min(m,n-1));
}

// Inclusion-exclusion over subsets of a[1..n1], enumerated as strictly
// increasing index sequences: k is the last index taken, m the remaining
// total after subtracting the chosen a[i], f the sign (+1 / -1) of the term.
void rc(int k,ll m,int f){
    // Pruning assumes a[i] >= 0 so m never grows deeper in the recursion -- TODO confirm input range.
    if(m<0)return;
    ans+=f*calc(n,m);
    ans=(ans%p+p)%p;
    for(int i=k+1;i<=n1;i++){
        rc(i,m-a[i],-f);
    }
}

// Reads one test case and prints the count modulo p.
void solve(){
    scanf("%d%d%d%lld",&n,&n1,&n2,&m);
    m-=n;
    for(int i=1;i<=n1;i++){
        scanf("%d",&a[i]);
    }
    int t;
    for(int i=1;i<=n2;i++){
        scanf("%d",&t);
        m-=(t-1);
    }
    if(m<0){
        printf("0\n");
        return;
    }
    ans=0LL;
    rc(0,m,1);
    printf("%lld\n",ans);
}

int main()
{
    //freopen("data.in","r",stdin);
    scanf("%d%lld",&T,&p);
    L.init(p);
    while(T--){
        solve();
    }
    return 0;
}