[洛谷P5205]【模板】多项式开根
题目大意:给你$n$项多项式$A(x)$,求出$B(x)$满足$B^2(x)\equiv A(x)\pmod{x^n}$
题解:考虑已经求出$B_0(x)$满足$B_0^2(x)\equiv A(x)\pmod{x^{\lceil\frac n 2\rceil}}$
$$
B(x)-B_0(x)\equiv0\pmod{x^{\lceil\frac n 2\rceil}}\\
B^2(x)−2B(x)B_0(x)+B_0^2(x)≡0\pmod{x^n}\\
A(x)-2B(x)B_0(x)+B_0^2(x)≡0\pmod{x^n}\\
B(x)\equiv\dfrac{A(x)+B_0^2(x)}{2B_0(x)}\pmod{x^n}\\
$$
update:(2019-2-10)
$$
B(x)\equiv\dfrac{A(x)+B_0^2(x)}{2B_0(x)}\pmod{x^n}\\
B(x)\equiv\dfrac{A(x)}{2B_0(x)}+\dfrac{B_0(x)}2\pmod{x^n}\\
$$
发现$\dfrac{B_0(x)}2$只会影响$B(x)$数组的前半部分(即$\pmod{x^{\lceil\frac n2\rceil}}$的部分),但是$B(x)\equiv B_0(x)\pmod{x^{\lceil\frac n2\rceil}}$,所以可以不做考虑,直接把$B_0(x)$拉过来
卡点:求$INV$时注意清空数组,防止因为$B$数组不干净导致出锅
C++ Code:
// Luogu P5205: polynomial square root modulo x^n over the prime field Z/998244353.
//
// Input  (stdin) : n, followed by the n coefficients a_0 .. a_{n-1} of A(x).
//                  The method assumes a_0 = 1 (the base case hard-codes b_0 = 1).
// Output (stdout): the n coefficients of B(x) with b_0 = 1 and B(x)^2 == A(x) (mod x^n),
//                  each followed by a space, then a newline.

#include <algorithm>
#include <cctype>
#include <cstdio>

// Capacity of every coefficient buffer; also the largest transform length supported.
const int maxn = 262144;
// NTT-friendly prime modulus and the modular inverse of 2.
const int mod = 998244353;
const int inv2 = (mod + 1) >> 1;
// Largest n whose transform length (next power of two >= 3 * ceil(n / 2)) fits in maxn.
const int max_terms = maxn / 3 * 2;

// Buffered integer I/O. These helpers deliberately live outside namespace std:
// declaring new types or objects inside std is undefined behaviour.
namespace fastio {

const int kBufSize = 1 << 21 | 3;

// Reads (up to kBufSize bytes of) stdin once and parses non-negative decimal integers
// from that in-memory copy.
struct Reader {
    char buf[kBufSize + 1];
    const char *cur;

    Reader() : cur(buf) {
        const std::size_t got = fread(buf, 1, kBufSize, stdin);
        buf[got] = '\0';  // sentinel: parsing can never run past the data that was read
    }

    // Skips whitespace, then reads a run of digits into x (x = 0 if there is none).
    Reader &operator>>(int &x) {
        while (isspace(static_cast<unsigned char>(*cur))) ++cur;
        x = 0;
        for (; isdigit(static_cast<unsigned char>(*cur)); ++cur) x = x * 10 + (*cur & 15);
        return *this;
    }
};

// Collects output in memory; written out when the buffer fills up and on destruction.
struct Writer {
    char buf[kBufSize];
    char *end;  // one past the last buffered character

    Writer() : end(buf) {}
    ~Writer() { flush(); }

    void flush() {
        fwrite(buf, 1, static_cast<std::size_t>(end - buf), stdout);
        end = buf;
    }

    void put(const char c) {
        if (end == buf + kBufSize) flush();
        *end++ = c;
    }

    // Writes x in decimal (a leading '-' for negative values).
    Writer &operator<<(const int x) {
        unsigned int v = static_cast<unsigned int>(x);
        if (x < 0) {
            put('-');
            v = 0u - v;
        }
        char digits[12];
        int len = 0;
        do {
            digits[len++] = static_cast<char>('0' + v % 10);
            v /= 10;
        } while (v);
        while (len) put(digits[--len]);
        return *this;
    }

    Writer &operator<<(const char c) {
        put(c);
        return *this;
    }
};

Reader in;
Writer out;

}  // namespace fastio

namespace Math {

// base^p modulo `mod` by binary exponentiation; expects 0 <= base < mod and p >= 0.
inline int pw(int base, int p) {
    int res = 1;
    for (; p; p >>= 1, base = static_cast<long long>(base) * base % mod)
        if (p & 1) res = static_cast<long long>(res) * base % mod;
    return res;
}

// Modular inverse through Fermat's little theorem (mod is prime); x must be non-zero mod `mod`.
inline int inv(int x) { return pw(x, mod - 2); }

}  // namespace Math

// Maps a value from [-mod, mod) into [0, mod).
inline void reduce(int &x) {
    if (x < 0) x += mod;
}

// Zero-fills [l, r); does nothing when the range is empty or reversed.
inline void clear(int *l, const int *r) {
    if (l < r) std::fill(l, const_cast<int *>(r), 0);
}

namespace Poly {

// Shared transform state, (re)built by init(): the current transform length, log2(lim) - 1,
// the bit-reversal permutation and the powers of the lim-th root of unity.
int lim, s, rev[maxn];
int Wn[maxn + 1];

// Prepares a transform of length lim = smallest power of two >= n (requires lim <= maxn).
inline void init(const int n) {
    lim = 1, s = -1;
    while (lim < n) lim <<= 1, ++s;
    for (int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
    // 3 is used as the generator: t has multiplicative order lim, and Wn[k] = t^k.
    const int t = Math::pw(3, (mod - 1) / lim);
    Wn[0] = 1;
    for (int i = 0; i < lim; ++i) Wn[i + 1] = static_cast<long long>(Wn[i]) * t % mod;
}

// In-place number-theoretic transform of A[0..lim). op != 0: forward transform;
// op == 0: inverse transform (forward pass, scale by 1/lim, then reverse A[1..lim)).
inline void FFT(int *A, const int op = 1) {
    for (int i = 1; i < lim; ++i)
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        const int t = lim / mid >> 1;  // stride into Wn for this butterfly size
        for (int i = 0; i < lim; i += mid << 1)
            for (int j = 0; j < mid; ++j) {
                const int X = A[i + j];
                const int Y = static_cast<long long>(A[i + j + mid]) * Wn[t * j] % mod;
                reduce(A[i + j] += Y - mod);
                reduce(A[i + j + mid] = X - Y);
            }
    }
    if (!op) {
        const int ilim = Math::inv(lim);
        for (int i = 0; i < lim; ++i) A[i] = static_cast<long long>(A[i]) * ilim % mod;
        std::reverse(A + 1, A + lim);
    }
}

// Writes into B[0..n) the inverse of A modulo x^n; A[0] must be invertible and n >= 1.
// Newton step: from B0 = A^{-1} mod x^len (len = ceil(n / 2)), B = B0 * (2 - A * B0).
// On return lim / rev / Wn describe a transform of length >= 3 * len, which SQRT reuses.
void INV(int *A, int *B, int n) {
    if (n == 1) {
        *B = Math::inv(*A);
        return;
    }
    const int len = (n + 1) >> 1;
    INV(A, B, len);
    // A cyclic length >= 3 * len is enough: wrapped-around terms only land below index
    // len, and those coefficients are not copied out.
    init(len * 3);
    static int C[maxn], D[maxn];
    std::copy(A, A + n, C);
    clear(C + n, C + lim);
    std::copy(B, B + len, D);
    clear(D + len, D + lim);
    FFT(D), FFT(C);
    for (int i = 0; i < lim; ++i)
        D[i] = (2 - static_cast<long long>(D[i]) * C[i] % mod + mod) * D[i] % mod;
    FFT(D, 0);
    // The low len coefficients already equal B0; only the upper part is new.
    std::copy(D + len, D + n, B + len);
}

// Writes into B[0..n) the square root of A modulo x^n with B[0] = 1.
// Requires n >= 1 and assumes A[0] == 1.
// Newton step: from B0 = sqrt(A) mod x^len, B = A / (2 * B0) + B0 / 2. The B0 / 2 term only
// touches the low len coefficients, where B == B0 already, so only A / (2 * B0) is computed
// and its coefficients len .. n-1 are copied into B.
void SQRT(int *A, int *B, int n) {
    if (n == 1) {
        *B = 1;
        return;
    }
    static int C[maxn], D[maxn];
    const int len = (n + 1) >> 1;
    SQRT(A, B, len);
    // INV reads B[0..n): make sure the not-yet-computed upper part is zero so the result
    // does not depend on what the caller's buffer held.
    clear(B + len, B + n);
    INV(B, D, n);
    clear(D + n, D + lim);  // lim was set by INV's outermost init(3 * len)
    std::copy(A, A + n, C);
    clear(C + n, C + lim);
    FFT(C), FFT(D);
    for (int i = 0; i < lim; ++i)
        D[i] = static_cast<long long>(C[i]) * D[i] % mod * inv2 % mod;
    FFT(D, 0);
    std::copy(D + len, D + n, B + len);
}

}  // namespace Poly

int n, A[maxn], B[maxn];

int main() {
    fastio::in >> n;
    if (n > max_terms) {
        // The fixed-size buffers cannot hold the transform this n would need.
        fprintf(stderr, "n = %d exceeds the supported maximum of %d\n", n, max_terms);
        return 1;
    }
    for (int i = 0; i < n; ++i) fastio::in >> A[i];
    if (n >= 1) Poly::SQRT(A, B, n);  // n <= 0: nothing to compute, print an empty line
    for (int i = 0; i < n; ++i) fastio::out << B[i] << ' ';
    fastio::out << '\n';
    return 0;
}