【模板】多项式半家桶 version 1

以下代码必须开 -O2

点击查看代码
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
#if defined(LOCAL)
// When LOCAL is defined, debug(...) forwards its arguments to fprintf(stderr, ...).
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
// Otherwise debug(...) expands to a no-op expression.
#define debug(...) void(0)
#endif
using LL = long long;
// Residue class modulo the compile-time modulus P, stored canonically in [0, P).
template <unsigned P> struct modint {
    // The modulus, exposed so generic code (e.g. the NTT root table) can write `mint::mod`.
    static constexpr unsigned mod = P;
    unsigned v; modint() : v(0) {}
    // Reduces any integer value (negatives included) into [0, P).
    template <class T> modint(T x) { x %= (int)P, v = x < 0 ? x + P : x; }
    modint operator+() const { return *this; }
    modint operator-() const { return modint(0) - *this; }
    // Fermat inverse v^(P-2): only an inverse when P is prime; v == 0 is asserted against.
    modint inv() const { return assert(v), qpow(*this, P - 2); }
    friend int raw(const modint &self) { return self.v; }
    // Binary exponentiation; b must be a non-negative integer (b >>= 1 never reaches 0 for b < 0).
    template <class T> friend modint qpow(modint a, T b) {
        modint r = 1;
        for (; b; b >>= 1, a *= a) if (b & 1) r *= a;
        return r;
    }
    // Unsigned wrap-around lets one conditional correction keep v in [0, P) for both += and -=.
    modint &operator+=(const modint &rhs) { if (v += rhs.v, v >= P) v -= P; return *this; }
    modint &operator-=(const modint &rhs) { if (v -= rhs.v, v >= P) v += P; return *this; }
    modint &operator*=(const modint &rhs) { v = 1ull * v * rhs.v % P; return *this; }
    modint &operator/=(const modint &rhs) { return *this *= rhs.inv(); }
    friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
    friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
    friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
    friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
    friend bool operator==(const modint &lhs, const modint &rhs) { return lhs.v == rhs.v; }
    friend bool operator!=(const modint &lhs, const modint &rhs) { return lhs.v != rhs.v; }
};
typedef modint<998244353> mint;

四则运算 & 求值

分别为 addition, subtract, multiple(bf), divide(bf), modulo(bf), getValue（bf 指暴力/朴素实现）。本节只给出 addition、multiple、divide、getValue 的代码，subtract 与 modulo 可按下面两个式子组合得到：
subtract(a, b) = addition(a, multiple(b, {-1}))

modulo(a, b) = a - (a / b) * b

std 的 <functional> 中有五个甚至六个类模板刻画算术运算：plus, minus, multiplies, divides, modulus, negate。乘法也不能叫 time，因为 std::time（<ctime>）已经用来表示时间了。所以多项式板子不能用这些名字（还好 std 里乘法和除法的名字带了第三人称单数的 -s，即 multiplies、divides，这里才能用 multiple、divide，否则更难想象该起什么名字）。

以及 apply & invoke 已经被 std 抢注,所以多项式代入求值怎么写函数名呢?getValue.

vector<mint> addition(const vector<mint> &a, const vector<mint> &b) {
    vector<mint> c(max(a.size(), b.size()));
    for (int i = 0; i < a.size(); i++) c[i] += a[i];
    for (int i = 0; i < b.size(); i++) c[i] += b[i];
    return c;
}
vector<mint> multiple(const vector<mint> &a, const vector<mint> &b) {
    vector<mint> c(a.size() + b.size() - 1);
    for (int i = 0; i < a.size(); i++) {
        for (int j = 0; j < b.size(); j++) c[i + j] += a[i] * b[j];
    }
    return c;
}
vector<mint> divide(vector<mint> a, vector<mint> b) {
    vector<mint> res(a.size() - b.size() + 1);
    reverse(b.begin(), b.end());
    for (int i = (int) a.size() - 1; i >= b.size() - 1; i--) {
        mint coe = res[i - b.size() + 1] = a[i] / b[0];
        for (int j = i; j > i - b.size(); j--) a[j] -= coe * b[i - j];
    }
    return res;
}
// Evaluates the polynomial a (index = degree) at x with Horner's rule, starting from the
// highest-degree coefficient. An empty polynomial evaluates to 0.
mint getValue(const vector<mint> &a, mint x) {
    mint acc = 0;
    for (auto it = a.rbegin(); it != a.rend(); ++it) acc = acc * x + *it;
    return acc;
}

NTT & multiple

// Smallest power of two strictly greater than x (e.g. 3 -> 4, 4 -> 8); returns 1 for x <= 0,
// where __builtin_clz would otherwise be undefined (x == 0) or give a shift by 32 (x < 0).
int glim(const int &x){return x <= 0 ? 1 : 1 << (32 - __builtin_clz(x));}
// Count of trailing zero bits of x; x must be nonzero (__builtin_ctz(0) is undefined).
int bitctz(const int &x) { const unsigned bits = x; return __builtin_ctz(bits); }
// Builds the root-of-unity table used by ntt():
//   wns[i] = G^((P - 1) / 2^(i + 1))  for i = 0 .. ctz(P - 1) - 1,
// where P is mint's modulus. If G is a primitive root mod P, wns[i] is a primitive
// 2^(i+1)-th root of unity.
vector<mint> getWns(mint G) {
    // P - 1 is the canonical value of -1; this avoids relying on a `mint::mod` member,
    // which modint does not necessarily declare.
    const int pMinus1 = raw(mint(-1));
    // 2^ctz(P - 1) is the longest transform P supports; for 998244353 = 119 * 2^23 + 1
    // this gives the same 23 entries as before.
    vector<mint> wns(__builtin_ctz(pMinus1));
    for (int i = 0; i < (int)wns.size(); i++) wns[i] = qpow(G, pMinus1 >> (i + 1));
    return wns;
}
// Root-of-unity table read by ntt(); generator G = 3, a primitive root of 998244353.
const vector<mint> wns = getWns(mint(3));
void ntt(vector<mint> &a, const int &op) {
    const int n = a.size();
    for (int i = 1, r = 0; i < n; i++) {
        r ^= n - (1 << (bitctz(n) - bitctz(i) - 1));
        if (i < r) swap(a[i], a[r]);
    }
    vector<mint> w(n);
    for (int k = 1, len = 2; len <= n; k <<= 1, len <<= 1) {
        const mint wn = wns[bitctz(k)];
        for (int i = raw(w[0] = 1); i < k; i++) w[i] = w[i - 1] * wn;
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < k; j++) {
                const mint x = a[i + j], y = a[i + j + k] * w[j];
                a[i + j] = x + y, a[i + j + k] = x - y;
            }
        }
    }
    if (op == -1) {
        const mint iz = mint(1) / n;
        for (int i = 0; i < n; i++) a[i] *= iz;
        reverse(a.begin() + 1, a.end());
    }
}
vector<mint> multiple(vector<mint> a, vector<mint> b) {
    int rLen = a.size() + b.size() - 1, len = glim(rLen);
    a.resize(len), ntt(a, 1);
    b.resize(len), ntt(b, 1);
    for (int i = 0; i < len; i++) a[i] *= b[i];
    ntt(a, -1), a.resize(rLen);
    return a;
}

getInv

vector<mint> getInv(const vector<mint> &a, int lim) {
    vector<mint> b = {1 / a[0]};
    for (int len = 2; len <= glim(lim); len <<= 1) {
        vector<mint> c(a.begin(), a.begin() + min(len, (int)a.size()));
        b.resize(len << 1), ntt(b, 1);
        c.resize(len << 1), ntt(c, 1);
        for (int i = 0; i < len << 1; i++)
            b[i] = b[i] * (2 - c[i] * b[i]);
        ntt(b, -1), b.resize(len);
    }
    b.resize(lim);
    return b;
}

divide & modulo

vector<mint> divide(vector<mint> f, vector<mint> g) {
    if (f.size() < g.size()) return {};
    int rLen = f.size() - g.size() + 1;
    reverse(f.begin(), f.end());
    reverse(g.begin(), g.end());
    f = multiple(f, getInv(g, rLen));
    f.resize(rLen), reverse(f.begin(), f.end());
    return f;
}
vector<mint> modulo(vector<mint> f, vector<mint> g) {
    int rLen = g.size() - 1;
    vector<mint> q = multiple(g, divide(f, g));
    q.resize(rLen), f.resize(rLen);
    for (int i = 0; i < rLen; i++) f[i] -= q[i];
    return f;
}

常系数齐次线性递推

vector<mint> qpow(vector<mint> a, int b, vector<mint> m) {
    vector<mint> r = {1};
    for (; b; b >>= 1, a = modulo(multiple(a, a), m)) {
        if (b & 1) r = modulo(multiple(r, a), m);
    }
    return r;
}
int main() {
    int n, k;
    scanf("%d%d", &n, &k);
    vector<mint> m(k + 1), a(k);
    m[k] = 1;
    for (int i = k - 1, x; i >= 0; i--) scanf("%d", &x), m[i] = -x;
    for (int i = 0, x; i < k; i++) scanf("%d", &x), a[i] = x;
    vector<mint> b = qpow({0, 1}, n, m);
    mint ans = 0;
    for (int i = 0; i < k; i++) ans += b[i] * a[i];
    printf("%d\n", raw(ans));
    return 0; 
}

拉格朗日插值

\(\ell_j(x)=\prod_{i\neq j}\frac{x-x_i}{x_j-x_i}\)

\(f(x)=\sum_i\ell_i(x)y_i\)

vector<mint> lagrange(const vector<pair<mint, mint>> &a) {
    vector<mint> ans, product = {1};
    for (int i = 0; i < a.size(); i++) 
        product = multiple(product, {-a[i].first, 1});
    for (int i = 0; i < a.size(); i++) {
        mint denos = 1;
        for (int j = 0; j < a.size(); j++) {
            if (i != j) denos *= a[i].first - a[j].first;
        }
        vector<mint> numes = divide(product, {-a[i].first, 1});
        ans = addition(ans, multiple(numes, {a[i].second / denos}));
    }
    return ans;
}

针对拉格朗日插值中需要的“除以 \((x-x_i)\)”操作，一种常数更小的写法：

vector<mint> divide2(vector<mint> a, mint b1) {
    vector<mint> res(a.size() - 1);
    for (int i = (int) a.size() - 1; i >= 1; i--) {
        mint coe = res[i - 1] = a[i];
        a[i - 1] -= a[i] * b1;
    }
    return res;
}
vector<mint> lagrange(const vector<mint> &a, const vector<mint> &b) {
    assert(a.size() == b.size());
    vector<mint> ans(a.size()), product = {1};
    for (int i = 0; i < a.size(); i++) 
        product = multiple(product, {-a[i], 1});
    for (int i = 0; i < a.size(); i++) {
        mint denos = 1;
        for (int j = 0; j < a.size(); j++)
            if (i != j) denos *= a[i] - a[j];
        vector<mint> numes = divide2(product, -a[i]);
        mint coe = b[i] / denos;
        for (int j = 0; j < a.size(); j++) ans[j] += numes[j] * coe;
    }
    return ans;
}

这里 divide2 也可以写成“退背包”的形式（按与上面相反的顺序，即从低次项往高次项递推，可以自己推一下）。而从最高次项开始做的 divide 在多项式只保留特定项数（被截断）之后就会出错。

posted @ 2023-09-24 21:36  caijianhong  阅读(19)  评论(0编辑  收藏  举报