COGS 2189 帕秋莉的超级多项式
放模板啦!
以后打比赛的时候直接复制过来。
说句实话,vector 的效率真的不怎么样,但似乎也还够用;最主要的是……写起来比较爽。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>
using namespace std;

typedef long long ll;
typedef vector <ll> poly;   // coefficients, lowest degree first, values mod P

// Polynomial toolbox over Z_P (P = 998244353, primitive root 3) built on NTT.
// Every "len"/"n" argument means "compute the result modulo x^len".
namespace Poly {
    const int L = 1 << 20;          // capacity of pos[]; lim must never exceed it
    const ll P = 998244353LL;
    int lim, pos[L];                // current transform size and its bit-reversal table

    // Modular exponentiation x^y mod P (y >= 0).
    inline ll fpow(ll x, ll y) {
        ll res = 1;
        for (; y > 0; y >>= 1) {
            if (y & 1) res = res * x % P;
            x = x * x % P;
        }
        return res;
    }

    const ll inv2 = fpow(2, P - 2);

    // x = (x + y) mod P, assuming both operands already lie in [0, P).
    template <typename T> inline void inc(T &x, T y) { x += y; if (x >= P) x -= P; }
    // x = (x - y) mod P, assuming both operands already lie in [0, P).
    template <typename T> inline void sub(T &x, T y) { x -= y; if (x < 0) x += P; }

    // Picks lim = smallest power of two >= len and fills the bit-reversal table.
    inline void prework(int len) {
        int l = 0;
        for (lim = 1; lim < len; lim <<= 1) ++l;
        pos[0] = 0;
        // Starting at i = 1 avoids evaluating a shift by (l - 1) == -1 when lim == 1,
        // which is undefined behavior.
        for (int i = 1; i < lim; i++)
            pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }

    // In-place NTT of length lim (opt = 1 forward, opt = -1 inverse incl. 1/lim scaling).
    // c is zero-padded to lim first; prework() must have been called for this lim.
    inline void ntt(poly &c, int opt) {
        c.resize(lim, 0);
        for (int i = 0; i < lim; i++)
            if (i < pos[i]) swap(c[i], c[pos[i]]);
        for (int i = 1; i < lim; i <<= 1) {
            ll wn = fpow(3, (P - 1) / (i << 1));
            if (opt == -1) wn = fpow(wn, P - 2);
            for (int len = i << 1, j = 0; j < lim; j += len) {
                ll w = 1;
                for (int k = 0; k < i; k++, w = w * wn % P) {
                    ll x = c[j + k], y = w * c[j + k + i] % P;
                    c[j + k] = (x + y) % P;
                    c[j + k + i] = (x - y + P) % P;
                }
            }
        }
        if (opt == -1) {
            ll inv = fpow(lim, P - 2);
            for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P;
        }
    }

    // Full product x * y with exactly x.size() + y.size() - 1 coefficients.
    inline poly operator * (const poly &x, const poly &y) {
        if (x.empty() || y.empty()) return poly();
        // Remember the true product length before ntt() pads u and v up to lim.
        const int n = (int)(x.size() + y.size() - 1);
        poly u = x, v = y;
        prework(n);
        ntt(u, 1), ntt(v, 1);
        poly res(lim);
        for (int i = 0; i < lim; i++) res[i] = u[i] * v[i] % P;
        ntt(res, -1);
        res.resize(n);
        return res;
    }

    // 1 / x mod x^len by Newton iteration y <- y (2 - x y). Requires x[0] != 0;
    // if x[0] == 0 the base case silently yields 0.
    poly getInv(poly x, int len) {
        if (len <= 0) return poly();
        x.resize(len);
        if (len == 1) return poly(1, fpow(x[0], P - 2));
        poly y = getInv(x, (len + 1) >> 1);
        prework(len << 1);
        poly u = x, v = y;
        ntt(u, 1), ntt(v, 1);
        poly res(lim);
        for (int i = 0; i < lim; i++)
            res[i] = v[i] * (2LL - u[i] * v[i] % P + P) % P;
        ntt(res, -1);
        res.resize(len);
        return res;
    }

    // In-place formal derivative; the length is kept and the top coefficient becomes 0.
    inline void direv(poly &c) {
        if (c.empty()) return;
        for (size_t i = 0; i + 1 < c.size(); i++) c[i] = c[i + 1] * (ll)(i + 1) % P;
        c.back() = 0;
    }

    // In-place formal integral with zero constant term; the length is kept, so the
    // highest-degree term of the exact integral is dropped.
    inline void integ(poly &c) {
        if (c.empty()) return;
        for (int i = (int)c.size() - 1; i > 0; i--) c[i] = c[i - 1] * fpow(i, P - 2) % P;
        c[0] = 0;
    }

    // ln(c) mod x^{c.size()} computed as the integral of c' / c (constant term 0).
    inline poly getLn(poly c) {
        poly a = getInv(c, (int)c.size());
        poly b = c;
        direv(b);
        poly res = b * a;
        res.resize(c.size());
        integ(res);
        return res;
    }

    // sqrt(x) mod x^len by Newton iteration y <- (y^2 + x) / (2y).
    // NOTE: the base case uses floating-point sqrt, not a modular square root, so
    // x[0] must be a perfect square of a small integer (typically 1).
    poly getSqrt(poly x, int len) {
        if (len <= 0) return poly();
        x.resize(len);
        if (len == 1) return poly(1, (ll)sqrt((double)x[0]));
        poly y = getSqrt(x, (len + 1) >> 1);
        poly u = x, v = y, w = getInv(y, len);
        prework(len << 1);   // getInv changed lim/pos, so redo it for this level
        ntt(u, 1), ntt(v, 1), ntt(w, 1);
        poly res(lim);
        for (int i = 0; i < lim; i++)
            res[i] = (v[i] * v[i] % P + u[i]) % P * w[i] % P * inv2 % P;
        ntt(res, -1);
        res.resize(len);
        return res;
    }

    // exp(x) mod x^len by Newton iteration y <- y (1 - ln y + x). Assumes x[0] == 0.
    poly getExp(poly x, int len) {
        if (len <= 0) return poly();
        x.resize(len, 0);
        if (len == 1) return poly(1, 1LL);
        poly y = getExp(x, (len + 1) >> 1);
        poly w = y;
        w.resize(len, 0);
        w = getLn(w);        // must run before prework below: getLn resets lim/pos
        poly u = x, v = y;
        prework(len << 1);
        u[0] = (u[0] + 1 - w[0] + P) % P;
        for (int i = 1; i < len; i++) u[i] = (u[i] - w[i] + P) % P;
        ntt(u, 1), ntt(v, 1);
        poly res(lim);
        for (int i = 0; i < lim; i++) res[i] = u[i] * v[i] % P;
        ntt(res, -1);
        res.resize(len);
        return res;
    }

    // x^y mod x^n computed as exp(y * ln x). Requires x[0] == 1.
    inline poly fpow(poly x, ll y, int n) {
        if (n <= 0) return poly();
        // ln/exp mod x^n only depend on the first n coefficients; resizing also
        // guarantees x[i] exists for every i < n in the scaling loop.
        x.resize(n);
        x = getLn(x);
        // Reduce the exponent so the product cannot overflow and negative y works.
        const ll e = (y % P + P) % P;
        for (int i = 0; i < n; i++) x[i] = x[i] * e % P;
        return getExp(x, n);
    }
}

// Reads one (optionally negative) integer from stdin; stops cleanly at EOF.
template <typename T> inline void read(T &X) {
    X = 0;
    T op = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') op = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        X = X * 10 + (ch - '0');
        ch = getchar();
    }
    X *= op;
}

// COGS 2189: apply sqrt -> inverse -> integrate -> exp -> inverse -> +1 -> ln -> +1
// -> k-th power -> derivative to the input polynomial, all mod x^n and mod P.
int main() {
    freopen("polynomial.in", "r", stdin);
    freopen("polynomial.out", "w", stdout);
    int n = 0, k = 0;
    read(n), read(k);
    if (n <= 0) return 0;
    poly a(n);
    for (int i = 0; i < n; i++) {
        read(a[i]);
        a[i] = (a[i] % Poly::P + Poly::P) % Poly::P;  // NTT expects values in [0, P)
    }
    a = Poly::getSqrt(a, n);
    a = Poly::getInv(a, n);
    Poly::integ(a);
    a = Poly::getExp(a, n);
    a = Poly::getInv(a, n);
    Poly::inc(a[0], 1LL);
    a = Poly::getLn(a);
    Poly::inc(a[0], 1LL);
    a = Poly::fpow(a, k, n);
    Poly::direv(a);
    for (int i = 0; i < n; i++) printf("%lld%c", a[i], " \n"[i == n - 1]);
    return 0;
}