[洛谷P3803] 【模板】多项式乘法(FFT, NTT)
题目大意:给定一个 $n$ 次多项式和一个 $m$ 次多项式,输出它们乘积($n+m$ 次多项式)的各项系数。
题解:$FFT$。由于数据范围不大,乘积多项式的每一项系数都小于模数 $998244353$,在模意义下计算不会丢失信息,所以也可以用 $NTT$ 精确求出结果。
卡点:无
C++ Code:
FFT:
// FFT solution: multiply two real-coefficient polynomials with an iterative
// complex FFT, then round the result back to integers.
#include <algorithm>
#include <cmath>
#include <cstdio>

// Array capacity. The transform length l is the smallest power of two that is
// > n + m; assuming n + m < 2^21 (the intended limits — confirm against the
// problem statement), l <= 2^21 = 2097152 < maxn. Same bound as the NTT code.
const int maxn = 2100010;
const double Pi = std::acos(-1.0);

int n, m;

// Minimal complex number with just the operations the FFT needs.
struct complex {
    double r, i;
    complex(double a = 0, double b = 0) : r(a), i(b) {}
    complex operator+(const complex &o) const { return complex(r + o.r, i + o.i); }
    complex operator-(const complex &o) const { return complex(r - o.r, i - o.i); }
    complex operator*(const complex &o) const {
        return complex(r * o.r - i * o.i, r * o.i + i * o.r);
    }
    // Scale both parts by 1/d. Returns *this (the old version had no return
    // statement in a value-returning function, which is undefined behavior).
    complex &operator/=(int d) {
        r /= d;
        i /= d;
        return *this;
    }
} a[maxn], b[maxn];

// rev[i]: i with its low `dig` bits reversed; l == 1 << dig.
int rev[maxn], dig, l;

// In-place iterative FFT of a[0..l-1].
// op == 1: forward transform; op == -1: inverse transform, including the 1/l scaling.
void FFT(complex *a, int op) {
    // Bit-reversal permutation so the butterflies can run bottom-up in place.
    for (int i = 0; i < l; i++)
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int mid = 1; mid < l; mid <<= 1) {
        // Principal (2*mid)-th root of unity; op == -1 conjugates it for the inverse.
        complex Wn(std::cos(Pi / mid), op * std::sin(Pi / mid));
        for (int i = 0; i < l; i += (mid << 1)) {
            complex W(1, 0);
            for (int j = 0; j < mid; j++, W = W * Wn) {
                complex X = a[i + j], Y = W * a[i + j + mid];
                a[i + j] = X + Y;
                a[i + j + mid] = X - Y;
            }
        }
    }
    if (op == -1)
        for (int i = 0; i < l; i++) a[i] /= l;  // was `i <= l`, which also touched a[l]
}

int main() {
    std::scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; i++) std::scanf("%lf", &a[i].r);
    for (int i = 0; i <= m; i++) std::scanf("%lf", &b[i].r);
    l = 1;
    while (l <= n + m) l <<= 1, dig++;
    // rev[0] == 0 by static zero-initialization. Starting at 1 also avoids the
    // undefined shift by (dig - 1) == -1 when l == 1.
    for (int i = 1; i < l; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (dig - 1));
    FFT(a, 1), FFT(b, 1);
    for (int i = 0; i < l; i++) a[i] = a[i] * b[i];
    FFT(a, -1);
    // floor(x + 0.5) rounds to nearest for negative values too (plain truncation does not).
    for (int i = 0; i <= n + m; i++) std::printf("%d ", (int)std::floor(a[i].r + 0.5));
    return 0;
}
NTT:
// NTT solution: exact polynomial multiplication modulo the NTT-friendly prime
// 998244353 (998244352 = 2^23 * 7 * 17, primitive root 3).
#include <algorithm>
#include <cstdio>

// A typedef instead of `#define int long long`: redefining a keyword as a
// macro is forbidden in a translation unit that includes standard headers.
typedef long long ll;

// Array capacity; see the FFT version for the l <= 2^21 assumption.
const int maxn = 2100010;
const ll mod = 998244353;
const ll P = 3, invP = 332748118;  // primitive root of mod and its inverse (3 * 332748118 = mod + 1)

int n, m;
ll a[maxn], b[maxn], invl;
// rev[i]: i with its low `dig` bits reversed; l == 1 << dig.
int rev[maxn], l, dig;

// base^p mod `mod` by binary exponentiation, for p >= 0.
inline ll pw(ll base, ll p) {
    ll ans = 1;
    base %= mod;
    while (p > 0) {
        if (p & 1) ans = ans * base % mod;
        base = base * base % mod;
        p >>= 1;
    }
    return ans;
}

// Modular inverse by Fermat's little theorem (mod is prime); i must not be a
// multiple of mod. Replaces the memoised recursion and its 16 MB table.
ll inv(ll i) { return pw(i, mod - 2); }

// In-place iterative NTT of a[0..l-1]; entries must lie in [0, mod).
// op == 1: forward transform; op == -1: inverse transform, including the 1/l scaling.
void NTT(ll *a, int op) {
    const ll Yx = (op == 1) ? P : invP;
    // Bit-reversal permutation (std::swap replaces the UB-prone XOR swap).
    for (int i = 0; i < l; i++)
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    for (int mid = 1; mid < l; mid <<= 1) {
        // Primitive (2*mid)-th root of unity modulo mod.
        ll Wn = pw(Yx, (mod - 1) / (mid << 1));
        for (int i = 0; i < l; i += (mid << 1)) {
            ll W = 1;
            for (int j = 0; j < mid; j++, W = W * Wn % mod) {
                ll X = a[i + j], Y = W * a[i + j + mid] % mod;
                a[i + j] = (X + Y) % mod;
                a[i + j + mid] = (X - Y + mod) % mod;
            }
        }
    }
    if (op == -1)
        for (int i = 0; i < l; i++) a[i] = a[i] * invl % mod;
}

int main() {
    std::scanf("%d%d", &n, &m);
    // Reduce inputs into [0, mod) so every product below stays within 64 bits.
    for (int i = 0; i <= n; i++) {
        std::scanf("%lld", &a[i]);
        a[i] = (a[i] % mod + mod) % mod;
    }
    for (int i = 0; i <= m; i++) {
        std::scanf("%lld", &b[i]);
        b[i] = (b[i] % mod + mod) % mod;
    }
    l = 1;
    while (l <= n + m) l <<= 1, dig++;
    invl = inv(l);
    // rev[0] == 0 by static zero-initialization.
    for (int i = 1; i < l; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (dig - 1));
    NTT(a, 1), NTT(b, 1);
    for (int i = 0; i < l; i++) a[i] = a[i] * b[i] % mod;
    NTT(a, -1);
    for (int i = 0; i <= n + m; i++) std::printf("%lld ", a[i]);
    return 0;
}