扩展卢卡斯定理
解决的问题:
\(C_n^m\pmod{p}\),\(p\)不是质数。
对\(p\)分解质数,假设为\(p_1^{c_1}*p_2^{c_2}*...*p_k^{c_k}\)。
对每个\(p_i^{c_i}\)求出在模意义下的\(C_n^m\),设为\(a_i\),我们的问题就变为:
\(\begin{cases}x\equiv a_1\pmod{p_1^{c_1}}\\ x\equiv a_2\pmod{p_2^{c_2}}\\...\\x\equiv a_k\pmod{p_k^{c_k}}\end{cases}\)
求这个\(x\)即为答案,因为模数互质,显然可以用CRT求。
于是考虑求\(C_n^m\pmod{p_i^{c_i}}\)。
发现不能用逆元求\(n!\),因为不一定存在逆元,\(x\)存在逆元的条件为\(\gcd(x,p)=1\)。
现在问题就是求\(n!\)在模\(p_i^{c_i}\)的意义下的逆元,即求\(n!\)模\(p_i^{c_i}\)的值,考虑既然模数只有一个质因子,我们将\(n!\)中的质因子\(p_i\)提出,剩下的数必定与\(p_i^{c_i}\)互质,exgcd求逆元即可。
以下引用自
以\(22!\mod\ 3^2\)为例:
按照\(3^2\)分段:\((1*2*3*4*5*6*7*8*9)*(10*11*12*13*14*15*16*17*18)*(19*20*21*22)\)
将3提出后为\((3^6*(1*2*3*4*5*6*7))*(1*2*4*5*7*8)*(10*11*13*14*16*17)*(19*20*22)\)
观察发现前\(\lfloor\frac{n}{p_i^{c_i}}\rfloor\)模意义下相同,求一组时候快速幂即可。\((19*20*22)\)直接暴力算。\((1*2*3*4*5*6*7)\)递归即可。
提出的\(3\)因子的数目可以直接算,为\(\sum\limits_{p^i<=n}\lfloor\frac{n}{p^i}\rfloor\),证明见lyd蓝书。
于是求完了。
code:
#include<bits/stdc++.h>
using namespace std;
const int maxp=1000010;
typedef long long ll;
ll n,m,mod;
ll fac[maxp];
inline ll power(ll x,ll k,ll mod)
{
ll res=1%mod;
while(k)
{
if(k&1)res=res*x%mod;
x=x*x%mod;k>>=1;
}
return res;
}
ll calc(ll x,ll p,ll mod)
{
if(x<=1)return 1;
ll res=1;
if(x>=mod)res=power(fac[mod-1],x/mod,mod);
if(x%mod)res=res*fac[x%mod]%mod;
return res*calc(x/p,p,mod)%mod;
}
void exgcd(ll a,ll b,ll& x,ll& y)
{
if(!b){x=1,y=0;return;}
exgcd(b,a%b,x,y);
ll z=x;x=y,y=z-(a/b)*y;
}
inline ll inv(ll a,ll mod)
{
if(!a)return 0;
ll x,y;exgcd(a,mod,x,y);
x=(x%mod+mod)%mod;
return x;
}
inline ll C(ll n,ll m,ll p,ll mod)
{
if(m>n)return 0;
fac[0]=1;
for(ll i=1;i<mod;i++)fac[i]=fac[i-1]*((i%p)?i:1)%mod;
ll a=calc(n,p,mod),b=inv(calc(m,p,mod),mod),c=inv(calc(n-m,p,mod),mod);
ll res=a*b%mod*c%mod;
int cnt=0;
for(ll i=p;i<=n;i*=p)cnt+=n/i;
for(ll i=p;i<=m;i*=p)cnt-=m/i;
for(ll i=p;i<=n-m;i*=p)cnt-=(n-m)/i;
//cerr<<cnt<<endl;
return res*power(p,cnt,mod)%mod;
}
int main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
scanf("%lld%lld%lld",&n,&m,&mod);
ll tmp=mod,ans=0;
for(ll i=2;i*i<=tmp;i++)
{
if(tmp%i)continue;
ll now=1;
while(tmp%i==0)now*=i,tmp/=i;
ans=(ans+C(n,m,i,now)*(mod/now)%mod*inv(mod/now,now)%mod)%mod;
}
if(tmp>1)ans=(ans+C(n,m,tmp,tmp)*(mod/tmp)%mod*inv(mod/tmp,tmp)%mod)%mod;
printf("%lld",(ans%mod+mod)%mod);
return 0;
}