求组合数
求组合数有以下四种情形 :
- 由公式 \(C_n^m=C_{n-1}^{n-1}+C_{n-1}^n\) 递推. 时间复杂度 \(O(nm)\). 一般 \(N\le 2000\).
- 预处理出阶乘, 再由 \(C_n^m=\frac{m!}{b!(a-b)!}\) 直接计算. 时间复杂度 \(O(NlogN)\). 一般 \(N\le 1e5\).
- 卢卡斯定理 : \(C_a^b=C_{a\%p}^{b\%p}C_{a/p}^{b/p}\). 时间复杂度 \(O(Plog_PN)\). 一般 \(a,b \le 1e18, p \le 1e5\).
- 高精度不取模, 将 \(C_n^m=\frac{m!}{b!(a-b)!}\) 的质数以及对应的次数求出. 将其转化为 \(p_1^{a_1}*p_2^{a_2}*p_3^{a_3}...p_k^{a_k}\) 的形式, 然后高精度乘法求.
下面是四种情形的代码
递推 :
void init() {
for (int i = 0; i < N; i++)
for (int j = 0; j <= i; j++) {
if (!j)
c[i][j] = 1;
else
c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
}
预处理阶乘 :
void init() {
fact[0] = infact[0] = 1;
for (int i = 1; i < N; ++i) fact[i] = (ll)fact[i - 1] * i % mod;
for (int i = 1; i < N; ++i) infact[i] = infact[i - 1] * qmi(i, mod - 2) % mod; // 求逆元
}
int C(int a, int b) {
if (a < b) return 0;
return (ll)fact[a] * infact[b] % mod * infact[a - b] % mod;
}
卢卡斯定理 :
int C(int a, int b, int p) {
if (b > a) return 0;
int res = 1;
for (int i = 1, j = a; i <= b; i++, j--) {
res = (ll)res * j % p;
res = (ll)res * qmi(i, p - 2, p) % p;
}
return res;
}
/*
int C(int a, int b, int p)
{
if (a < b) return 0;
int down = 1, up = 1;
for (int i = a, j = 1; j <= b; i --, j ++ )
{
up = (ll)up * i % p;
down = (ll)down * j % p;
}
return (ll)up * qmi(down, p - 2, p) % p;
}
*/
int lucas(ll a, ll b, int p) {
if (a < p && b < p) return C(a, b, p);
return (ll)C(a % p, b % p, p) * lucas(a / p, b / p, p) % p;
}
高精度 :
void init(int n) {
for (int i = 2; i <= n; ++i) {
if (!st[i]) pri[cnt++] = i;
for (int j = 0; pri[j] * i <= n; ++j) {
st[pri[j] * i] = true;
if (i % pri[j] == 0) break;
}
}
}
int get(int n, int p) {
int res = 0;
while (n) res += n / p, n /= p;
return res;
}
vector<int> mul(vector<int> a, int b) {
int t = 0;
vector<int> c;
for (int i = 0; i < a.size(); ++i) {
t += a[i] * b;
c.pb(t % 10);
t /= 10;
}
while (t) c.pb(t % 10), t /= 10;
return c;
}
int main() {
//freopen("in.txt", "r", stdin);
IO;
int a, b;
cin >> a >> b;
init(a);
for (int i = 0; i < cnt; ++i)
sum[i] = get(a, pri[i]) - get(b, pri[i]) - get(a - b, pri[i]);
vector<int> v;
v.pb(1);
for (int i = 0; i < cnt; ++i)
for (int j = 0; j < sum[i]; ++j)
v = mul(v, pri[i]);
for (int i = v.size() - 1; i >= 0; --i) cout << v[i];
cout << '\n';
return 0;
}