关于求阶乘和阶乘逆元的预处理和加速

因为求逆元的复杂度其实比较高,所以我们要尽可能地少用快速幂求逆元。

在下面代码中只用快速幂求了一次逆元,其余均是线性复杂度。

vector<Z> fac(n + 1, 1), invfac(n + 1);
    for (int i = 1; i <= n; i++) {
        fac[i] = fac[i - 1] * i;  		//阶乘
    }
    invfac[n] = fac[n].inv();  			//唯一一次快速幂求逆元
    for (int i = n; i; i--) {
        invfac[i - 1] = invfac[i] * i;	//阶乘逆元
    }

完整代码(包含重载):

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

constexpr int mod = 1e9 + 7;

template <class T>
T power(T a, int b) {
    T res = 1;
    for (; b; b >>= 1, a *= a)
        if (b & 1)
            res *= a;
    return res;
}

int norm(int x) {
    if (x < 0) x += mod;
    if (x >= mod) x -= mod;
    return x;
}
struct Z {
    int x;
    Z(int x = 0) : x(norm(x)) {}
    int val() const {
        return x;
    }
    Z operator-() const {
        return Z(norm(mod - x));
    }
    Z inv() const {
        assert(x != 0);
        return power(*this, mod - 2);
    }
    Z &operator*=(const Z &rhs) {
        x = ll(x) * rhs.x % mod;
        return *this;
    }
    Z &operator+=(const Z &rhs) {
        x = norm(x + rhs.x);
        return *this;
    }
    Z &operator-=(const Z &rhs) {
        x = norm(x - rhs.x);
        return *this;
    }
    Z &operator/=(const Z &rhs) {
        return *this *= rhs.inv();
    }
    friend Z operator*(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res *= rhs;
        return res;
    }
    friend Z operator+(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res += rhs;
        return res;
    }
    friend Z operator-(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res -= rhs;
        return res;
    }
    friend Z operator/(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res /= rhs;
        return res;
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, k;
    cin >> n >> k;
    vector<Z> fac(n + 1, 1), invfac(n + 1);
    for (int i = 1; i <= n; i++) {
        fac[i] = fac[i - 1] * i;
    }
    invfac[n] = fac[n].inv();
    for (int i = n; i; i--) {
        invfac[i - 1] = invfac[i] * i;
    }

    Z ans = 0;
    for (int i = 0; i <= min(k, n); i++) {
        ans += fac[n] * invfac[i] * invfac[n - i];
    }

    cout << ans.val() << '\n';

    return 0;
}
posted @ 2022-09-05 09:04  blockche  阅读(134)  评论(0编辑  收藏  举报