P3301 [SDOI2013]方程

思路

容斥的挺好的练习题
对于第二个条件,可以直接使m减去suma2,使得第二个条件舍去,然后m再减去n,使得问题转化成有n1个变量要满足小于等于某个数的条件,其他的随便取,求整数解的个数
对n1,以2^n的复杂度枚举至少哪些不符合限制,然后容斥(至少0个-至少1个+至少2个....)
然后用隔板法可以得到每一次答案为

\[\left(\begin{matrix}m-midt-1\\n-1\end{matrix}\right) \]

注意本题模数不是质数,需要EXLucas,同时由于本题卡时间,所以要预处理MOD的质因数和mul函数要用的阶乘

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
using namespace std;
int pow(int a,int b,int MOD){
    int ans=1;
    while(b){
        if(b&1)
            ans=(1LL*ans*a)%MOD;
        a=(1LL*a*a)%MOD;
        b>>=1;
    }
    return ans;
}
int exgcd(int a,int b,int &x,int &y){
    if(b==0){
        x=1;
        y=0;
        return a;
    }
    int req=exgcd(b,a%b,x,y);
    int t=x;
    x=y;
    y=t-a/b*y;
    return req;
}
int inv(int a,int p){
    if(!a) 
        return 0;
    int x,y;
    exgcd(a,p,x,y);
    x=((x%p+p)%p);
    if(!x)
        x+=p;
    return x;
}
int f[10210];
int mul(int n,int pi,int pk){//get n!/pi^a%p^k
    if(n<pi)
        return f[n];
    return 1LL*pow(f[pk-1],n/pk,pk)*f[n%pk]%pk*mul(n/pi,pi,pk)%pk;
}
int C(int n,int m,int Mod,int pi,int pk){
    if(m>n)
        return 0;
    f[0]=1;
    for(int i=1;i<=pk;i++)
        if(i%pi)
            f[i]=(f[i-1]*i)%pk;
        else
            f[i]=f[i-1];
    int jcn=mul(n,pi,pk),jcm=mul(m,pi,pk),jcnm=mul(n-m,pi,pk),k=0;
    for(int i=n;i;i/=pi)
        k+=i/pi;
    for(int i=m;i;i/=pi)
        k-=i/pi;
    for(int i=n-m;i;i/=pi)
        k-=i/pi;
    int ans=1LL*jcn*inv(jcm,pk)%pk*inv(jcnm,pk)%pk*pow(pi,k,pk)%pk;
    return 1LL*ans*(Mod/pk)%Mod*inv(Mod/pk,pk)%Mod;
}
int exLucas(int n,int m,int Mod){
    int ans=0;
    if(Mod==10007){
        ans=(ans+C(n,m,Mod,10007,10007))%Mod;
    }
    else if(Mod==262203414){
        ans=(ans+C(n,m,Mod,2,2)%Mod+C(n,m,Mod,3,3)%Mod+C(n,m,Mod,11,11)%Mod+C(n,m,Mod,397,397)%Mod+C(n,m,Mod,10007,10007)%Mod)%Mod;
    }
    else{
        ans=(ans+C(n,m,Mod,5,125)+C(n,m,Mod,7,343)+C(n,m,Mod,101,10201))%Mod;
    }
    return ans;
}
int n,n1,n2,m,A[20],MOD,ans,T,sum2;
int bitcount(int x){
    int ans=0;
    while(x){
        ans++;
        x&=(x-1);
    }
    return ans;
}
int bi[(1<<9)];
signed main(){
    scanf("%lld %lld",&T,&MOD);
    for(int i=0;i<(1<<8);i++)
        bi[i]=bitcount(i);
    while(T--){
        memset(A,0,sizeof(A));
        ans=0;
        sum2=0;
        scanf("%lld %lld %lld %lld",&n,&n1,&n2,&m);
        for(int i=1;i<=n1+n2;i++){
            scanf("%lld",&A[i]);
            if(i>n1)
                sum2+=A[i]-1;
        }
        m-=sum2;
        for(int i=0;i<(1<<(n1));i++){
            int midt=0,midcnt=0;
            for(int j=1;j<=n1;j++)
                if((i>>(j-1))&1)
                    midt+=A[j],midcnt++;
            ans=(ans+(1LL*((bi[i]&1)?-1:1)*exLucas(m-midt-1,n-1,MOD)%MOD+MOD))%MOD;
        } 
        printf("%lld\n",ans);
    }
    return 0;
}
posted @ 2019-03-14 23:56  dreagonm  阅读(160)  评论(0编辑  收藏  举报