数论杂记——快速求解组合数 C(n,m) 取模
模板:
#include <bits/stdc++.h> using namespace std; typedef long long ll; const ll mod = 998244353; const int Max = 1e6 + 10; ll fact[Max],ifact[Max]; ll n,m; ll pow_mod(ll n,ll k) { ll res=1; n=n%mod; while (k>0) { if (k&1) res=res*n%mod; n=n*n%mod; k>>=1; } return res; } void init() { fact[0]=ifact[0]=1; for (int i=1;i<Max;i++) { fact[i]=(fact[i-1]*i)%mod; ifact[i]=pow_mod(fact[i],mod-2); } } ll C(ll n, ll m) { if (n<m||m<0) return 0; return (fact[n]*ifact[m]%mod)*ifact[n-m]%mod; } int main() { init(); cout<<C(?,?)<<endl; return 0; }
例题:CodeForces1312D
#include <bits/stdc++.h> using namespace std; typedef long long ll; const ll mod = 998244353; const int Max = 1e6 + 10; ll fact[Max],ifact[Max]; ll n,m; ll pow_mod(ll n,ll k) { ll res=1; n=n%mod; while (k>0) { if (k&1) res=res*n%mod; n=n*n%mod; k>>=1; } return res; } void init() { fact[0]=ifact[0]=1; for (int i=1;i<Max;i++) { fact[i]=(fact[i-1]*i)%mod; ifact[i]=pow_mod(fact[i],mod-2); } } ll C(ll n, ll m) { if (n<m||m<0) return 0; return (fact[n]*ifact[m]%mod)*ifact[n-m]%mod; } int main() { init(); cin>>n>>m; ll ans=0; for (int i=n-1;i<=m;i++) { ans+=C(i-1,n-2)%mod*(n-2)%mod*pow_mod(2,n-3)% mod; } cout<<ans%mod; return 0; }