计算大组合数
在题目要求对组合数取模时,可以用卢卡斯定理等方法,但是要求算出组合数的具体值时,就需要用高精度计算大组合数了
原理
对于素数 \(p\) 在 \(n!\) 中的幂次 \(cnt(n,p)\) ,有如下公式:
\[cnt(n,p)=\sum_{i=1}^{\infty} \lfloor\frac{n}{p^{i}}\rfloor
\]
可以得出素数 \(p\) 在组合数 \(\displaystyle \binom{n}{m}\) 中的幂次为:
\[cnt(n,p)-cnt(n-m,p)-cnt(m,p)
\]
所以我们可以用线性筛预处理出 \(1\sim n\) 的素数,并计算它们在组合数的幂次,将组合数转化为标准分解的形式,再用高精度乘法,即可得到答案
代码
#include<bits/stdc++.h>
using namespace std;
const int MAX_N = 600 + 5;
int prime[MAX_N];
bool is_prime[MAX_N];
int cnt;
struct bint {
vector<int> v;
bint() {
*this = 0;
}
bint(int x) {
*this = x;
}
bint& operator=(int x) {
v.clear();
do {
v.push_back(x % 10);
} while (x /= 10);
return *this;
}
bint& operator=(const bint& x) {
v.resize(x.v.size());
for(int i = v.size() - 1; i >= 0; i--)
v[i] = x.v[i];
return *this;
}
};
ostream& operator<<(ostream& out, const bint& x)
{
for(int i = x.v.size() - 1; i >= 0; i--)
out << (char)(x.v[i] + '0');
return out;
}
istream& operator>>(istream& in, bint& x)
{
string str;
in >> str;
x.v.clear();
for(int i = str.size() - 1; i >= 0; i--)
x.v.push_back(str[i] - '0');
return in;
}
bint operator+(const bint& a, const bint& b)
{
bint res;
res.v.clear();
bool carry = false;
int len = max(a.v.size(), b.v.size());
for(int i = 0; i < len; i++) {
int add = 0;
if(i < (int)a.v.size()) add += a.v[i];
if(i < (int)b.v.size()) add += b.v[i];
if(carry) {
add++;
carry = false;
}
if(add >= 10) {
add -= 10;
carry = true;
}
res.v.push_back(add);
}
if(carry)
res.v.push_back(1);
return res;
}
bint operator*(const bint& a, const bint& b)
{
bint res;
res.v.resize(a.v.size() + b.v.size());
for(int i = 0; i < (int)a.v.size(); i++) {
for(int j = 0; j < (int)b.v.size(); j++) {
res.v[i + j] += a.v[i] * b.v[j];
res.v[i + j + 1] += res.v[i + j] / 10;
res.v[i + j] %= 10;
}
}
int siz = res.v.size();
while(siz > 1 && res.v[siz - 1] == 0)
siz--;
res.v.resize(siz);
return res;
}
void sieve(int n)
{
memset(is_prime, 1, sizeof(is_prime));
is_prime[1] = false;
cnt = 0;
for(int i = 2; i <= n; i++) {
if(is_prime[i])
prime[++cnt] = i;
for(int j = 1; j <= cnt && i * prime[j] <= n; j++) {
is_prime[i * prime[j]] = false;
if(i % prime[j] == 0)
break;
}
}
}
int get_cnt(int n, int p)
{
int res = 0, t = p;
while(t <= n) {
res += (n / t);
t *= p;
}
return res;
}
bint C(int n, int m)
{
if(m > n)
return bint(0);
bint res = 1;
for(int i = 1; i <= cnt && prime[i] <= n; i++) {
int t = get_cnt(n, prime[i]) - get_cnt(n - m, prime[i]) - get_cnt(m, prime[i]);
for(int j = 1; j <= t; j++)
res = res * prime[i];
}
return res;
}
int main()
{
int n, m;
cin >> n >> m;
sieve(n);
cout << C(n, m) << endl;
return 0;
}