计算大组合数

在题目要求对组合数取模时,可以用卢卡斯定理等方法,但是要求算出组合数的具体值时,就需要用高精度计算大组合数了

原理

对于素数 \(p\)\(n!\) 中的幂次 \(cnt(n,p)\) ,有如下公式:

\[cnt(n,p)=\sum_{i=1}^{\infty} \lfloor\frac{n}{p^{i}}\rfloor \]

可以得出素数 \(p\) 在组合数 \(\displaystyle \binom{n}{m}\) 中的幂次为:

\[cnt(n,p)-cnt(n-m,p)-cnt(m,p) \]

所以我们可以用线性筛预处理出 \(1\sim n\) 的素数,并计算它们在组合数的幂次,将组合数转化为标准分解的形式,再用高精度乘法,即可得到答案

代码

#include<bits/stdc++.h>
using namespace std;

const int MAX_N = 600 + 5;
int prime[MAX_N];
bool is_prime[MAX_N];
int cnt;

struct bint {
    vector<int> v;

    bint() {
        *this = 0;
    }
    
    bint(int x) {
        *this = x;
    }

    bint& operator=(int x) {
        v.clear();
        do {
            v.push_back(x % 10);
        } while (x /= 10);
        return *this;
    }

    bint& operator=(const bint& x) {
        v.resize(x.v.size());
        for(int i = v.size() - 1; i >= 0; i--)
            v[i] = x.v[i];
        return *this;
    }
};

ostream& operator<<(ostream& out, const bint& x)
{
    for(int i = x.v.size() - 1; i >= 0; i--)
        out << (char)(x.v[i] + '0');
    return out;
}

istream& operator>>(istream& in, bint& x)
{
    string str;
    in >> str;
    x.v.clear();
    for(int i = str.size() - 1; i >= 0; i--)
        x.v.push_back(str[i] - '0');
    return in;
}

bint operator+(const bint& a, const bint& b)
{
    bint res;
    res.v.clear();
    bool carry = false;
    int len = max(a.v.size(), b.v.size());
    for(int i = 0; i < len; i++) {
        int add = 0;
        if(i < (int)a.v.size()) add += a.v[i];
        if(i < (int)b.v.size()) add += b.v[i];
        if(carry) {
            add++;
            carry = false;
        }
        if(add >= 10) {
            add -= 10;
            carry = true;
        }
        res.v.push_back(add);
    }
    if(carry)
        res.v.push_back(1);
    return res;
}

bint operator*(const bint& a, const bint& b)
{
    bint res;
    res.v.resize(a.v.size() + b.v.size());
    for(int i = 0; i < (int)a.v.size(); i++) {
        for(int j = 0; j < (int)b.v.size(); j++) {
            res.v[i + j] += a.v[i] * b.v[j];
            res.v[i + j + 1] += res.v[i + j] / 10;
            res.v[i + j] %= 10;
        }
    }
    int siz = res.v.size();
    while(siz > 1 && res.v[siz - 1] == 0)
        siz--;
    res.v.resize(siz);
    return res;
}

void sieve(int n)
{
    memset(is_prime, 1, sizeof(is_prime));
    is_prime[1] = false;
    cnt = 0;
    for(int i = 2; i <= n; i++) {
        if(is_prime[i])
            prime[++cnt] = i;
        for(int j = 1; j <= cnt && i * prime[j] <= n; j++) {
            is_prime[i * prime[j]] = false;
            if(i % prime[j] == 0)
                break;
        }
    }
}

int get_cnt(int n, int p)
{
    int res = 0, t = p;
    while(t <= n) {
        res += (n / t);
        t *= p;
    }
    return res;
}

bint C(int n, int m)
{
    if(m > n)
        return bint(0);
    bint res = 1;
    for(int i = 1; i <= cnt && prime[i] <= n; i++) {
        int t = get_cnt(n, prime[i]) - get_cnt(n - m, prime[i]) - get_cnt(m, prime[i]);
        for(int j = 1; j <= t; j++)
            res = res * prime[i];
    }
    return res;
}

int main()
{
    int n, m;
    cin >> n >> m;
    sieve(n);
    cout << C(n, m) << endl;
    return 0;
}
posted @ 2021-12-20 15:39  f(k(t))  阅读(82)  评论(0编辑  收藏  举报