[SDOI2013]方程

对 \(x_i \geq A_i\) 的条件，令 \(x_i' = x_i - (A_i - 1) \geq 1\)，相当于直接从 \(m\) 中减去 \(A_i - 1\)。
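对 \(x_i \leq A_i\) 的条件用容斥：枚举不满足条件（即 \(x_i \geq A_i + 1\)）的变量集合 \(S\)，从 \(m\) 中减去 \(\sum_{i \in S} A_i\) 得到 \(m'\)，贡献为 \((-1)^{|S|}\binom{m' - 1}{n - 1}\)，其中 \(\binom{m'-1}{n-1}\) 是 \(x_1 + \dots + x_n = m'\) 的正整数解个数。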
模数 \(p\) 不一定是质数，组合数用扩展卢卡斯（exLucas）对每个质数幂分别计算，再用中国剩余定理合并。

#include<iostream>
#include<cstdio>
#define ll long long 

// Extended Euclid (iterative form): sets x, y so that a*x + b*y == gcd(a, b).
// It walks the same quotient sequence as the classic recursive version,
// so it yields exactly the same (x, y) pair.
void exgcd(long long a,long long b,long long &x,long long &y){
	long long s0 = 1,s1 = 0; // coefficients of the original a
	long long t0 = 0,t1 = 1; // coefficients of the original b
	while(b != 0){
		const long long q = a / b;
		const long long r = a - q * b; // same value as a % b
		a = b;
		b = r;
		const long long ns = s0 - q * s1;
		s0 = s1;
		s1 = ns;
		const long long nt = t0 - q * t1;
		t0 = t1;
		t1 = nt;
	}
	x = s0;
	y = t0;
}

// Greatest common divisor by the Euclidean algorithm (loop form).
// gcd(a, 0) == a; in particular gcd(0, 0) == 0.
long long gcd(long long a,long long b){
	while(b != 0){
		const long long r = a % b;
		a = b;
		b = r;
	}
	return a;
}

// Modular inverse of a modulo p via the extended Euclidean algorithm.
// Assumes gcd(a, p) == 1; otherwise the value returned is not an inverse.
inline long long inv(long long a,long long p){
	long long coef = 0,other = 0;
	exgcd(a,p,coef,other);
	// coef may be negative; a single +p shift is assumed to be enough
	// (holds when coef >= -p, the usual Bezout bound for positive inputs).
	return (coef + p) % p;
}

// Least common multiple of a and b.
// Divides before multiplying so the intermediate value stays as small as possible.
// Returns 0 when either argument is 0 (matching std::lcm); without this guard,
// lcm(0, 0) would divide by gcd(0, 0) == 0.
inline long long lcm(long long a,long long b){
	if(a == 0 || b == 0)return 0;
	return a / gcd(a,b) * b;
}

// Modular exponentiation by repeated squaring: returns a^b mod p, always in [0, p).
// Requires p >= 1. A negative exponent is treated as 0 (the old loop never terminated on it).
// Products of two residues must fit in long long, so p - 1 must stay below about 3.03e9.
inline long long fastpow(long long a,long long b,long long p){
	long long ans = 1 % p;   // 1 % p so that p == 1 yields 0 even when b == 0
	a %= p;
	if(a < 0)a += p;         // normalize a negative base into [0, p)
	while(b > 0){
		if(b & 1)ans = ans * a % p;
		b >>= 1;
		a = a * a % p;
	}
	return ans;
}

// Reads the next non-negative integer from stdin, skipping any non-digit characters before it.
// A '-' is skipped like any other non-digit, so negative numbers are not supported.
// Returns 0 if end of input is reached before any digit is seen.
inline long long read(){
	long long ans = 0;
	int c = getchar();   // int, not char, so EOF stays distinguishable from a real byte
	while(c != EOF && !(c >= '0' && c <= '9'))c = getchar();
	while(c >= '0' && c <= '9'){
		ans = ans * 10 + (c - '0');
		c = getchar();
	}
	return ans;
}

// Returns n! with every factor of p removed, reduced mod pk
// (callers pass pk as a power of the prime p).
// Iterative form of n! -> (n/p)! * (product of p-coprime numbers <= n):
// each round multiplies in the p-coprime part of 1..n, then continues with n/p.
inline long long f(long long n,long long p,long long pk){
	// Product of the numbers in [1, pk] coprime to p; this pattern repeats every pk numbers.
	long long period = 1;
	for(long long i = 1;i <= pk;++i)
		if(i % p != 0)period = period * i % pk;
	long long acc = 1;
	for(;n != 0;n /= p){
		const long long whole = fastpow(period,n / pk,pk);
		// Leftover numbers after the last complete period.
		long long tail = 1;
		for(long long i = pk * (n / pk);i <= n;++i)
			if(i % p != 0)tail = tail * (i % pk) % pk;
		acc = acc * whole % pk * tail % pk;
	}
	return acc;
}

// Exponent of the prime p in n! (Legendre's formula): n/p + n/p^2 + n/p^3 + ...
inline long long g(long long n,long long p){
	long long total = 0;
	while(n >= p){
		n /= p;
		total += n;
	}
	return total;
}

inline ll c_pk(ll n,ll m,ll p,ll pk){
	ll fn = f(n,p,pk),fm = inv(f(m,p,pk),pk),fnm = inv(f(n - m,p,pk),pk);
	ll mi = fastpow(p,g(n,p) - g(m,p) - g(n - m,p),pk);
	return fn % pk * fm % pk * fnm % pk * mi % pk;
}

// Scratch space for exlucas (1-indexed): A[i] holds the i-th prime-power factor of
// the modulus, and B[i] holds the residue c_pk computed modulo A[i].
ll A[1001],B[1001];

// exlucas combines them by CRT: find x with x = B[i] (mod A[i]) for every i.

inline ll exlucas(ll n,ll m,ll p){
	if(n < m || n <= 0)return 0;
	ll P = p,tot = 0;
	for(ll i = 2;i * i <= P;++i){
		if(!(p % i)){
			ll pk = 1;
			while(!(p % i))
			pk *= i,p /= i;
			A[++tot] = pk;
			B[tot] = c_pk(n,m,i,pk);
		}
	}
	if(p != 1)
	A[++tot] = p,B[tot] = c_pk(n,m,p,p);
	ll ans = 0;
	for(ll i = 1;i <= tot;++i){
		ll M = P / A[i],t = inv(M,A[i]);
		ans = (ans + B[i] * M % P * t % P) % P;
	}
	return ans;
}

// Upper bounds a[1..n1] for the first n1 variables of the current test case (1-indexed).
// Assumes n1 is small (main enumerates all 2^n1 subsets), far below this capacity.
ll a[200];

// Reads T and the modulus p, then T test cases "n n1 n2 m" followed by
// n1 upper bounds and n2 lower bounds. For each case it prints, modulo p,
// the number of positive integer solutions of x_1 + ... + x_n = m that satisfy the bounds.
int main(){
	long long T,p;
	scanf("%lld%lld",&T,&p);
	for(;T;--T){
		long long n,n1,n2,m;
		scanf("%lld%lld%lld%lld",&n,&n1,&n2,&m);
		// Upper bounds x_k <= a[k], k = 1..n1, are handled by inclusion-exclusion below.
		for(int k = 1;k <= n1;++k)
			scanf("%lld",&a[k]);
		// Lower bounds x_k >= lo: substituting x_k' = x_k - (lo - 1) >= 1 shrinks m by lo - 1.
		for(int k = 1;k <= n2;++k){
			long long lo;
			scanf("%lld",&lo);
			m -= lo - 1;
		}
		long long total = 0;
		const long long subsets = 1ll << n1;
		for(long long mask = 0;mask < subsets;++mask){
			// Force each upper bound in mask to be violated (x_k >= a[k] + 1) by removing a[k] from m.
			long long remain = m;
			int violated = 0;
			for(int k = 1;k <= n1;++k){
				if((mask >> (k - 1)) & 1){
					++violated;
					remain -= a[k];
				}
			}
			// C(remain - 1, n - 1) positive solutions, signed by subset parity.
			long long term = exlucas(remain - 1,n - 1,p) % p;
			if(violated & 1)term = -term;
			total = (total + term) % p;
		}
		std::cout<<(total + p) % p<<std::endl;
	}
}

posted @ 2021-09-08 22:24  fhq_treap  阅读(55)  评论(0)  编辑  收藏  举报