【模板】多项式半家桶 version 1
以下代码必须开 -O2
header
点击查看代码
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// debug(...) forwards to fprintf(stderr, ...) only when compiled with -DLOCAL;
// in any other build it expands to a no-op expression.
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
// Shorthand for 64-bit signed integers.
typedef long long LL;
// Integer modulo P, kept normalized in [0, P).
// Requirements on P: P <= INT_MAX (the converting constructor casts P to int),
// and P prime for inv() / operator/, which use Fermat's little theorem (a^(P-2)).
template <unsigned P> struct modint {
  // The modulus, exposed so callers can write `mint::mod` (e.g. the NTT root table).
  static constexpr unsigned mod = P;
  unsigned v;
  modint() : v(0) {}
  // Converts any integer, including negative values, to its canonical residue.
  template <class T> modint(T x) { x %= (int)P, v = x < 0 ? x + P : x; }
  modint operator+() const { return *this; }
  modint operator-() const { return modint(0) - *this; }
  // Multiplicative inverse; asserts that the value is nonzero.
  modint inv() const { return assert(v), qpow(*this, P - 2); }
  friend int raw(const modint &self) { return self.v; }
  // Binary exponentiation a^b for a non-negative integer exponent b.
  template <class T> friend modint qpow(modint a, T b) {
    modint r = 1;
    for (; b; b >>= 1, a *= a) if (b & 1) r *= a;
    return r;
  }
  // v + rhs.v < 2P fits in unsigned because P <= INT_MAX; the subtraction
  // relies on unsigned wrap-around to detect underflow.
  modint &operator+=(const modint &rhs) { if (v += rhs.v, v >= P) v -= P; return *this; }
  modint &operator-=(const modint &rhs) { if (v -= rhs.v, v >= P) v += P; return *this; }
  modint &operator*=(const modint &rhs) { v = 1ull * v * rhs.v % P; return *this; }
  modint &operator/=(const modint &rhs) { return *this *= rhs.inv(); }
  friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
  friend bool operator==(const modint &lhs, const modint &rhs) { return lhs.v == rhs.v; }
  friend bool operator!=(const modint &lhs, const modint &rhs) { return lhs.v != rhs.v; }
};
typedef modint<998244353> mint;
四则运算 & 求值
分别为 addition, subtract, multiple(bf), divide(bf), modulo(bf), getValue。其中 bf 指暴力(brute force)实现,复杂度为 \(O(nm)\);subtract 和 modulo 没有单独给出代码,可按下面两式用其余函数组合得到:
subtract(a, b) = addition(a, multiple(b, {-1}))
modulo(a, b) = a - (a / b) * b
std 的 `<functional>` 中有五个甚至六个类模板刻画算术操作:plus, minus, multiplies, divides, modulus, negate。另外乘法也不能写成 time,因为它已经被 `<ctime>` 里取时间的 `time` 函数占用了。所以多项式板子就不能用这些名字了(还好标准库的乘法和除法用的是三单形式 multiplies / divides,要不然更加难以想象)。
以及 apply & invoke 已经被 std 抢注,所以多项式代入求值怎么写函数名呢?getValue.
// Coefficient-wise sum of two polynomials.
// The result has max(|a|, |b|) coefficients; the shorter operand is treated as zero-padded.
vector<mint> addition(const vector<mint> &a, const vector<mint> &b) {
  vector<mint> sum(max(a.size(), b.size()));
  for (size_t i = 0; i < sum.size(); i++) {
    if (i < a.size()) sum[i] += a[i];
    if (i < b.size()) sum[i] += b[i];
  }
  return sum;
}
// Schoolbook polynomial product in O(|a| * |b|).
// Returns |a| + |b| - 1 coefficients, or {} (the zero polynomial) when either
// operand is empty; without this guard |a| + |b| - 1 wraps around for two empty inputs.
vector<mint> multiple(const vector<mint> &a, const vector<mint> &b) {
  if (a.empty() || b.empty()) return {};
  vector<mint> c(a.size() + b.size() - 1);
  for (size_t i = 0; i < a.size(); i++)
    for (size_t j = 0; j < b.size(); j++) c[i + j] += a[i] * b[j];
  return c;
}
// Schoolbook polynomial long division in O(|a| * |b|).
// Returns the quotient of a / b with |a| - |b| + 1 coefficients, or {} when
// deg a < deg b. The remainder is discarded.
// Requires b non-empty with a nonzero leading coefficient b.back().
vector<mint> divide(vector<mint> a, vector<mint> b) {
  assert(!b.empty());
  const int n = a.size(), m = b.size();
  if (n < m) return {};
  vector<mint> res(n - m + 1);
  const mint leadInv = b[m - 1].inv();
  // Eliminate the current top coefficient of a at each step. Indices are signed:
  // comparing against size_t here would wrap (|b| == 1 never terminated).
  for (int i = n - 1; i >= m - 1; i--) {
    const mint coe = res[i - m + 1] = a[i] * leadInv;
    for (int k = 0; k < m; k++) a[i - k] -= coe * b[m - 1 - k];
  }
  return res;
}
// Evaluates the polynomial sum_i a[i] * x^i at x using Horner's rule, O(|a|).
// An empty coefficient vector evaluates to 0.
mint getValue(const vector<mint> &a, mint x) {
  mint acc = 0;
  for (auto it = a.rbegin(); it != a.rend(); ++it) acc = acc * x + *it;
  return acc;
}
NTT & multiple(这里的 NTT 版 multiple 用来替换上面的暴力版;两个重载不能同时放在同一文件中,否则调用时会产生重载二义)
// Smallest power of two strictly greater than x (glim(4) == 8, glim(5) == 8),
// hence always a valid NTT length for x output coefficients. Requires x < 2^30.
// x <= 0 yields 1 rather than evaluating __builtin_clz(0), which is undefined.
int glim(int x) { return x <= 0 ? 1 : 1 << (32 - __builtin_clz(x)); }
// Number of trailing zero bits of x; x must be nonzero (__builtin_ctz(0) is undefined).
int bitctz(int x) { return __builtin_ctz(x); }
// Table of roots of unity: roots[i] = G^((mod - 1) >> (i + 1)), which is a primitive
// 2^(i+1)-th root of unity when G is a primitive root modulo mod.
// 998244353 - 1 = 119 * 2^23, so the 23 entries cover NTT lengths up to 2^23.
vector<mint> getWns(mint G) {
  vector<mint> roots(23);
  for (int i = 0; i < (int)roots.size(); i++)
    roots[i] = qpow(G, (mint::mod - 1) >> (i + 1));
  return roots;
}
const vector<mint> wns = getWns(mint(3));
// In-place iterative number-theoretic transform over mint.
// Assumes a.size() is a power of two (callers pass glim(...) lengths) and at most
// 2^23, the size covered by the 23 precomputed roots in wns.
// op == 1: forward transform. op == -1: inverse transform, including the 1/n scaling.
void ntt(vector<mint> &a, const int &op) {
const int n = a.size();
// Bit-reversal permutation: going from i - 1 to i flips the low ctz(i) + 1 bits of i,
// so r (the reversed index) flips its top ctz(i) + 1 bits, i.e. r ^= n - (n >> (ctz(i) + 1)).
for (int i = 1, r = 0; i < n; i++) {
r ^= n - (1 << (bitctz(n) - bitctz(i) - 1));
if (i < r) swap(a[i], a[r]);
}
vector<mint> w(n);
// Butterfly passes over blocks of length len = 2k; wn = wns[log2 k] is a primitive len-th root.
for (int k = 1, len = 2; len <= n; k <<= 1, len <<= 1) {
const mint wn = wns[bitctz(k)];
// Fill w[j] = wn^j for j < k; raw(w[0] = 1) sets w[0] and evaluates to 1, the start index.
for (int i = raw(w[0] = 1); i < k; i++) w[i] = w[i - 1] * wn;
for (int i = 0; i < n; i += len) {
for (int j = 0; j < k; j++) {
const mint x = a[i + j], y = a[i + j + k] * w[j];
a[i + j] = x + y, a[i + j + k] = x - y;
}
}
}
if (op == -1) {
// Applying the forward transform twice maps a[i] to n * a[(n - i) % n],
// so scale by 1/n and reverse a[1..n-1] to obtain the inverse transform.
const mint iz = mint(1) / n;
for (int i = 0; i < n; i++) a[i] *= iz;
reverse(a.begin() + 1, a.end());
}
}
// Polynomial product via NTT in O(L log L), where L = |a| + |b|.
// Returns |a| + |b| - 1 coefficients, or {} (the zero polynomial) when either
// operand is empty; otherwise rLen could be 0 and glim/ntt would see a degenerate length.
vector<mint> multiple(vector<mint> a, vector<mint> b) {
  if (a.empty() || b.empty()) return {};
  const int rLen = a.size() + b.size() - 1, len = glim(rLen);
  a.resize(len), ntt(a, 1);
  b.resize(len), ntt(b, 1);
  // Pointwise product in the transformed domain equals cyclic convolution of length len >= rLen.
  for (int i = 0; i < len; i++) a[i] *= b[i];
  ntt(a, -1), a.resize(rLen);
  return a;
}
getInv
// Power-series inverse: returns b (lim coefficients) with a(x) * b(x) == 1 (mod x^lim).
// Newton iteration b <- b * (2 - a * b) mod x^len, doubling len each pass; O(lim log lim).
// Requires a non-empty with a[0] != 0 (1 / a[0] asserts otherwise). Returns {} for lim <= 0.
vector<mint> getInv(const vector<mint> &a, int lim) {
  if (lim <= 0) return {};
  vector<mint> b = {1 / a[0]};
  // On entry to each pass b holds len / 2 correct terms; stop as soon as that already
  // covers lim (the old bound ran one extra, most expensive pass for power-of-two lim).
  for (int len = 2; (len >> 1) < lim; len <<= 1) {
    vector<mint> c(a.begin(), a.begin() + min(len, (int)a.size()));
    // Length 2 * len avoids wrap-around: deg(b * c * b) < 2 * len.
    b.resize(len << 1), ntt(b, 1);
    c.resize(len << 1), ntt(c, 1);
    for (int i = 0; i < len << 1; i++)
      b[i] = b[i] * (2 - c[i] * b[i]);
    ntt(b, -1), b.resize(len);
  }
  b.resize(lim);
  return b;
}
divide & modulo(这里的 NTT 版 divide 与上面暴力版的签名完全相同,只能保留其中一个)
// Polynomial division in O(n log n) via reversal: rev(q) = rev(f) * rev(g)^{-1} mod x^rLen.
// Returns the quotient (|f| - |g| + 1 coefficients), or {} when deg f < deg g.
// Requires g's leading coefficient g.back() to be nonzero (getInv asserts otherwise).
vector<mint> divide(vector<mint> f, vector<mint> g) {
  if (f.size() < g.size()) return {};
  const int rLen = f.size() - g.size() + 1;
  reverse(f.begin(), f.end());
  reverse(g.begin(), g.end());
  // Only the low rLen coefficients of rev(f) influence the result mod x^rLen,
  // so truncate before multiplying to keep the NTT small.
  f.resize(rLen);
  f = multiple(f, getInv(g, rLen));
  f.resize(rLen), reverse(f.begin(), f.end());
  return f;
}
// Remainder f mod g = f - g * (f / g), returned with exactly |g| - 1 coefficients
// (zero-padded when shorter). Requires g non-empty with a nonzero leading coefficient.
vector<mint> modulo(vector<mint> f, vector<mint> g) {
  const int rLen = g.size() - 1;
  const vector<mint> quo = divide(f, g);
  f.resize(rLen);
  // deg f < deg g: f is already its own remainder; skip multiplying by an empty quotient.
  if (quo.empty()) return f;
  vector<mint> q = multiple(g, quo);
  q.resize(rLen);
  for (int i = 0; i < rLen; i++) f[i] -= q[i];
  return f;
}
常系数齐次线性递推
// Computes a(x)^b mod m(x) by binary exponentiation; b must be >= 0.
// For b >= 1 the result has |m| - 1 coefficients (as produced by modulo); for b == 0 it is {1}.
vector<mint> qpow(vector<mint> a, int b, vector<mint> m) {
  vector<mint> r = {1};
  while (b) {
    if (b & 1) r = modulo(multiple(r, a), m);
    b >>= 1;
    // Skip the squaring after the last bit: its result would never be used.
    if (b) a = modulo(multiple(a, a), m);
  }
  return r;
}
// Constant-coefficient homogeneous linear recurrence a_i = sum_{j=1..k} f_j * a_{i-j}.
// Input: n k, then f_1..f_k, then a_0..a_{k-1}. Output: a_n.
// With m(x) = x^k - sum_j f_j x^{k-j}, a_n = sum_i [x^i](x^n mod m) * a_i.
int main() {
  int n, k;
  scanf("%d%d", &n, &k);
  vector<mint> m(k + 1), a(k);
  m[k] = 1;
  for (int i = k - 1, x; i >= 0; i--) scanf("%d", &x), m[i] = -x;
  for (int i = 0, x; i < k; i++) scanf("%d", &x), a[i] = x;
  vector<mint> b = qpow({0, 1}, n, m);
  // qpow returns just {1} for n == 0; pad so b[i] is valid for every i < k.
  b.resize(k);
  mint ans = 0;
  for (int i = 0; i < k; i++) ans += b[i] * a[i];
  printf("%d\n", raw(ans));
  return 0;
}
拉格朗日插值
\(\ell_j(x)=\prod_{i\neq j}\frac{x-x_i}{x_j-x_i}\)
\(f(x)=\sum_i\ell_i(x)y_i\)
// Lagrange interpolation from (x_i, y_i) pairs with distinct x_i:
// f(x) = sum_i y_i * prod_{j != i} (x - x_j) / (x_i - x_j).
// Builds prod_j (x - x_j) once, then divides out each linear factor.
vector<mint> lagrange(const vector<pair<mint, mint>> &a) {
  vector<mint> all = {1};
  for (const auto &pt : a) all = multiple(all, {-pt.first, 1});
  vector<mint> ans;
  for (size_t i = 0; i < a.size(); i++) {
    const mint xi = a[i].first;
    mint denom = 1;
    for (size_t j = 0; j < a.size(); j++)
      if (j != i) denom *= xi - a[j].first;
    const vector<mint> basis = divide(all, {-xi, 1});
    ans = addition(ans, multiple(basis, {a[i].second / denom}));
  }
  return ans;
}
针对拉格朗日插值所需要的除以 \((x-x_i)\) 的减少常数写法:
// Divides a(x) by the monic linear factor (x + b1) via synthetic division in O(|a|).
// Returns the quotient (|a| - 1 coefficients), or {} for an empty a; the remainder is dropped.
// Lagrange interpolation passes b1 = -x_i to divide by (x - x_i).
vector<mint> divide2(vector<mint> a, mint b1) {
  if (a.empty()) return {};
  vector<mint> res(a.size() - 1);
  for (int i = (int)a.size() - 1; i >= 1; i--) {
    res[i - 1] = a[i];
    a[i - 1] -= a[i] * b1;
  }
  return res;
}
// Lagrange interpolation through the points (a[i], b[i]) with distinct a[i], O(n^2).
// Returns the n coefficients of the unique polynomial of degree < n through them.
// Builds prod_i (x - a[i]) once and divides out each factor with divide2.
vector<mint> lagrange(const vector<mint> &a, const vector<mint> &b) {
  assert(a.size() == b.size());
  const int n = a.size();
  vector<mint> all = {1};
  for (const mint &xi : a) all = multiple(all, {-xi, 1});
  vector<mint> ans(n);
  for (int i = 0; i < n; i++) {
    mint denom = 1;
    for (int j = 0; j < n; j++)
      if (j != i) denom *= a[i] - a[j];
    const vector<mint> basis = divide2(all, -a[i]);
    const mint scale = b[i] / denom;
    for (int j = 0; j < n; j++) ans[j] += basis[j] * scale;
  }
  return ans;
}
这里 divide2 也可以写成退背包的形式(从低次项往高次项推),顺序倒过来自己推一下。而从最高次项开始消元的写法(divide、divide2)在被除多项式被截断、只保留低位特定项数之后就会算错。
本文来自博客园,作者:caijianhong,转载请注明原文链接:https://www.cnblogs.com/caijianhong/p/17726728.html