【BZOJ3129】方程(SDOI2013)-容斥原理+扩展Lucas定理
测试地址:方程
做法:本题需要用到容斥原理+扩展Lucas定理。
首先,如果没有任何限制,那么非负整数解的数量就是,这个可以用隔板法求出,那么要求正整数解的话,其实只要转化成求的非负整数解数量即可,显然上面的方程可以转化为来求。
现在我们考虑限制,对于第二种限制,我们可以把转化为,就可以转化成求另一个方程的正整数解数。然而第一种限制就不好办了,这时候看到,想到容斥,即我们如果要求满足所有条件的方案数,我们可以用没有任何限制的方案数,减去强制某一个条件满足不了的方案数,加上强制某两个条件满足不了的方案数……显然如果一个限制满足不了,就会变成的形式,这个就能很容易地化成第二种限制求解了。
那么最后一个问题,也是这道题最困难的一个点,就是求大组合数取模。我们知道如果为质数,可以用Lucas定理求出,但这里可能是合数,我们只能将其质因数分解为这样的形式,然后分别求组合数对取模的结果,然后用中国剩余定理或合并模线性方程的方法将最后的结果求出。
现在问题变成求组合数对取模的结果,其中为质数。我们可以根据公式将问题转化为求对取模的结果。我们把拆成长度为的若干段,然后把所有的倍数提出来,显然这些东西乘起来等于,对于递归求解,对于剩下的部分,我们可以把连续的段看成一个周期,因为,那么我们可以先直接快速幂求出完整的个周期的乘积,然后剩下的部分最多长为,直接计算即可。这一个部分的时间复杂度应该是,如果预处理出缺项前缀积(就是把的倍数舍掉的阶乘)常数会小很多。
这里要使用欧拉定理:来求出逆元,显然由欧拉函数的定义有。特别地,如果结果中包含因子,那么我们无法直接求得逆元,所以我们在计算时要独立计算因子的幂数,仅对不含因子的部分求逆元即可。
BZOJ的题面中缺了一个比较重要的信息,数据范围中的是固定的一些数,这些数中最大的可分解出的差不多在左右的水准,所以上述的的算法可以通过此题。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int T,n1,n2,totcnt=0;
ll n,m,a[20],p,fac[10],cnt[10],F[20010];
ll power(ll a,ll b,ll p)
{
ll s=1,ss=a;
while(b)
{
if (b&1) s=s*ss%p;
ss=ss*ss%p,b>>=1;
}
return s;
}
ll exgcd(ll a,ll b)
{
ll x0=1,y0=0,x1=0,y1=1;
while(b)
{
ll q=a/b,tmp;
tmp=x0,x0=x1,x1=tmp-q*x1;
tmp=y0,y0=y1,y1=tmp-q*y1;
tmp=a,a=b,b=tmp%b;
}
return x0;
}
ll calc(ll n,ll p,ll t,ll P,ll &s)
{
s=0;
if (n<=0) return 1;
ll x=1;
x=power(F[P],n/P,P);
x=x*F[n%P]%P;
x=x*calc(n/p,p,t,P,s)%P;
s+=n/p;
return x;
}
void init(ll p,ll P)
{
F[0]=1;
for(ll i=1;i<=P;i++)
{
if (i%p!=0) F[i]=F[i-1]*i%P;
else F[i]=F[i-1];
}
}
ll exlucas(ll n,ll m)
{
if (m>n) return 0;
ll lastr,lasta;
for(int i=1;i<=totcnt;i++)
{
ll x,sumt=0,t;
ll P=power(fac[i],cnt[i],p+1);
init(fac[i],P);
x=calc(n,fac[i],cnt[i],P,t);
sumt+=t;
x=x*power(calc(m,fac[i],cnt[i],P,t),P/fac[i]*(fac[i]-1)-1,P)%P;
sumt-=t;
x=x*power(calc(n-m,fac[i],cnt[i],P,t),P/fac[i]*(fac[i]-1)-1,P)%P;
sumt-=t;
x=x*power(fac[i],sumt,P)%P;
if (i>1)
{
ll x0;
x0=exgcd(lasta,P);
x0=(x0*(x-lastr)%P+P)%P;
lastr=(lastr+lasta*x0)%(lasta*P);
lasta=lasta*P;
}
else lastr=x,lasta=P;
}
return lastr;
}
int main()
{
scanf("%d%lld",&T,&p);
ll x=p;
for(ll i=2;i<=(ll)sqrt(p)+1;i++)
if (x%i==0)
{
fac[++totcnt]=i;
cnt[totcnt]=0;
while(x%i==0)
{
cnt[totcnt]++;
x/=i;
}
}
if (x!=1) fac[++totcnt]=x,cnt[totcnt]=1;
while(T--)
{
scanf("%lld%d%d%lld",&n,&n1,&n2,&m);
for(int i=1;i<=n1+n2;i++)
scanf("%lld",&a[i]);
for(int i=1;i<=n2;i++)
m-=a[n1+i]-1;
m-=n;
ll ans=0;
for(int i=0;i<(1<<n1);i++)
{
int tot=0;
ll x=m;
for(int j=0;j<n1;j++)
if (i&(1<<j))
{
tot++;
x-=a[j+1];
}
if (tot%2) ans=(ans-exlucas(x+n-1,n-1)+p)%p;
else ans=(ans+exlucas(x+n-1,n-1))%p;
}
printf("%lld\n",ans);
}
return 0;
}