Luogu 4238 【模板】多项式求逆
疯狂补板中。
考虑倍增实现。
假设多项式只有一个常数项(即 $n = 1$),直接求这个常数的逆元就可以了(模数是质数,用费马小定理快速幂即可)。
现在假如要求$G(x)$
$$F(x)G(x) \equiv 1 (\mod x^n)$$
而我们已经求出了$H(x)$
$$F(x)H(x) \equiv 1(\mod x^{\left \lceil \frac{n}{2} \right \rceil})$$
两式相减,
$$F(x)(G(x) - H(x)) \equiv 0(\mod x^{\left \lceil \frac{n}{2} \right \rceil})$$
由于$F(x)$的常数项不为$0$,$F(x)$在模$x^{\left \lceil \frac{n}{2} \right \rceil}$意义下可逆(逆元正是$H(x)$)。两边同乘$H(x)$就能把$F(x)$消掉,于是
$$G(x) - H(x) \equiv 0(\mod x^{\left \lceil \frac{n}{2} \right \rceil})$$
两边平方,
$$G(x)^2 + H(x)^2 - 2G(x)H(x) \equiv 0(\mod x^n)$$
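注意到后面的模数也平方了,变成$x^{2\left \lceil \frac{n}{2} \right \rceil}$;由于$2\left \lceil \frac{n}{2} \right \rceil \ge n$,在模$x^n$意义下同样成立。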
因为多项式$G(x) - H(x)$中次数$\in [0, \left \lceil \frac{n}{2} \right \rceil)$的项的系数全都是$0$,所以平方之后次数小于$2\left \lceil \frac{n}{2} \right \rceil$的项的系数也全都是$0$,特别地,次数在$[0, n)$之间的项的系数全都是$0$。
两边乘上$F(x)$,再利用$F(x)G(x) \equiv 1 \pmod{x^n}$化简,
$$F(x)G(x)^2 + F(x)H(x)^2 - 2F(x)G(x)H(x) \equiv G(x) + F(x)H(x)^2 - 2H(x) \equiv 0(\mod x^n)$$
就得到了
$$G(x) \equiv 2H(x) - F(x)H(x)^2(\mod x^n)$$
递归实现比较清爽,非递归的比递归的快挺多的。
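每一层的代价是$O(n \log n)$,而规模每次减半,即$T(n) = T(\left \lceil \frac{n}{2} \right \rceil) + O(n \log n)$,所以总时间复杂度为$O(n \log n)$。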
实现的时候有两个小细节:
1、$H(x)$的长度是$\left \lceil \frac{n}{2} \right \rceil$,$F(x)$只需保留前$n$项,所以$F(x)H(x)^2$的长度接近$2n$,NTT 的长度要开到不小于$2n$,否则循环卷积会把高次项绕回到低次项上。
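2、递归的时候注意那个上取整符号:要递归到$\left \lceil \frac{n}{2} \right \rceil$而不是$\left \lfloor \frac{n}{2} \right \rfloor$(代码中写作 `(len + 1) >> 1`)。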
Code:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

const int N = 1 << 20;

int n;
ll f[N], g[N];  // input polynomial and its inverse (globals, so zero-initialised)

// NTT-based polynomial routines. Every function takes the prime modulus P explicitly;
// gn = 3 is used as the primitive root for every modulus.
namespace Poly {
    const int L = 1 << 20;
    const ll gn = 3;
    const ll Mod[4] = {0, 998244353LL, 1004535809LL, 469762049LL};
    int lim, pos[L];  // current transform length and its bit-reversal permutation

    // x * y mod P by binary doubling, so no intermediate value exceeds 2 * P.
    // Currently unused by this program.
    inline ll fmul(ll x, ll y, ll P) {
        ll res = 0;
        for (x %= P; y; y >>= 1) {
            if (y & 1) res = (res + x) % P;
            x = (x + x) % P;
        }
        return res;
    }

    // x^y mod P by fast exponentiation (y >= 0).
    inline ll fpow(ll x, ll y, ll P) {
        ll res = 1LL;
        for (x %= P; y > 0; y >>= 1) {
            if (y & 1) res = res * x % P;
            x = x * x % P;
        }
        return res;
    }

    // Sets lim to the smallest power of two >= len and fills pos[0..lim) with the
    // bit-reversal permutation for that length.
    inline void prework(int len) {
        int l = 0;
        for (lim = 1; lim < len; lim <<= 1, ++l);
        pos[0] = 0;
        // Start at i = 1: for i == 0 the formula would shift by (l - 1) == -1 when lim == 1.
        for (int i = 1; i < lim; i++) pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }

    // In-place iterative NTT of c[0..lim) modulo P: forward for opt == 1, inverse
    // (including the 1/lim scaling) for opt == -1. prework() must have set lim/pos.
    inline void ntt(ll *c, ll opt, ll P) {
        for (int i = 0; i < lim; i++)
            if (i < pos[i]) swap(c[i], c[pos[i]]);
        for (int i = 1; i < lim; i <<= 1) {
            // (2i)-th root of unity, assuming gn generates the multiplicative group mod P.
            ll wn = fpow(gn, (P - 1) / (i << 1), P);
            if (opt == -1) wn = fpow(wn, P - 2, P);
            for (int len = i << 1, j = 0; j < lim; j += len) {
                ll w = 1;
                for (int k = 0; k < i; k++, w = w * wn % P) {
                    ll x = c[j + k], y = c[j + k + i] * w % P;
                    c[j + k] = (x + y) % P, c[j + k + i] = (x - y + P) % P;
                }
            }
        }
        if (opt == -1) {
            ll inv = fpow(lim, P - 2, P);
            for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P;
        }
    }

    ll f[L], g[L];  // scratch buffers for inv()

    // Computes b[0..len) with A(x) * B(x) == 1 (mod x^len), where A is a[0..len) and
    // coefficients are in [0, P). a[0] must be nonzero mod P; otherwise fpow yields 0
    // and the result is not an inverse. Only b[0..len) is written; its prior contents
    // are ignored.
    void inv(ll *a, ll *b, int len, ll P) {
        // Nothing to compute; also stops len == 0 from recursing on (0 + 1) >> 1 == 0 forever.
        if (len <= 0) return;
        if (len == 1) { b[0] = fpow(a[0], P - 2, P); return; }
        int half = (len + 1) >> 1;  // ceil(len / 2): the recursion must round up
        inv(a, b, half, P);         // b[0..half) = H with A * H == 1 (mod x^half)
        // deg(F * H^2) <= (len - 1) + 2 * (half - 1) <= 2 * len - 2, so a transform
        // length >= 2 * len avoids cyclic wrap-around.
        prework(len << 1);
        for (int i = 0; i < lim; i++) f[i] = g[i] = 0;
        for (int i = 0; i < len; i++) f[i] = a[i];
        // Only the first `half` coefficients of b are valid at this point; copying
        // b[half..len) would pull in whatever the caller's buffer happened to hold.
        for (int i = 0; i < half; i++) g[i] = b[i];
        ntt(f, 1, P), ntt(g, 1, P);
        // Point-wise Newton step: G = H * (2 - F * H).
        for (int i = 0; i < lim; i++) g[i] = g[i] * (2LL - f[i] * g[i] % P + P) % P;
        ntt(g, -1, P);
        for (int i = 0; i < len; i++) b[i] = g[i];
    }
}

// Reads one (optionally negative) integer from stdin into X.
// Leaves X == 0 if EOF is reached before any digit, instead of spinning forever.
template <typename T> inline void read(T &X) {
    X = 0;
    T op = 1;
    int ch = getchar();
    for (; ch != EOF && (ch > '9' || ch < '0'); ch = getchar())
        if (ch == '-') op = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

int main() {
    read(n);
    const ll P = Poly::Mod[1];
    for (int i = 0; i < n; i++) {
        read(f[i]);
        f[i] = (f[i] % P + P) % P;  // reduce into [0, P) so negative input is handled too
    }
    Poly::inv(f, g, n, P);
    for (int i = 0; i < n; i++) printf("%lld%c", g[i], i == (n - 1) ? '\n' : ' ');
    return 0;
}
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;

// Capacity: the largest transform has length 4 * len, where len is the largest power of
// two below n, so these arrays are large enough for n <= 131072.
const int N = 3e5 + 5;
const ll P = 998244353LL;

int n, lim, pos[N];
ll a[N], f[2][N], tmp[N];

// Reads one (optionally negative) integer from stdin into X.
// Leaves X == 0 if EOF is reached before any digit, instead of spinning forever.
template <typename T> inline void read(T &X) {
    X = 0;
    T op = 1;
    int ch = getchar();
    for (; ch != EOF && (ch > '9' || ch < '0'); ch = getchar())
        if (ch == '-') op = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

// x^y mod P by fast exponentiation; expects 0 <= x < P and y >= 0.
inline ll fpow(ll x, ll y) {
    ll res = 1LL;
    for (; y > 0; y >>= 1) {
        if (y & 1) res = res * x % P;
        x = x * x % P;
    }
    return res;
}

// In-place iterative NTT of c[0..lim) modulo P: forward for opt == 1, inverse
// (including the 1/lim scaling) for opt == -1. pos[] must hold the bit-reversal
// permutation for the current lim. swap here is std::swap.
inline void ntt(ll *c, int opt) {
    for (int i = 0; i < lim; i++)
        if (i < pos[i]) swap(c[i], c[pos[i]]);
    for (int i = 1; i < lim; i <<= 1) {
        ll wn = fpow(3, (P - 1) / (i << 1));  // (2i)-th root of unity (3 is the primitive root)
        if (opt == -1) wn = fpow(wn, P - 2);
        for (int len = i << 1, j = 0; j < lim; j += len) {
            ll w = 1;
            for (int k = 0; k < i; k++, w = w * wn % P) {
                ll x = c[j + k], y = c[j + k + i] * w % P;
                c[j + k] = (x + y) % P, c[j + k + i] = (x - y + P) % P;
            }
        }
    }
    if (opt == -1) {
        ll inv = fpow(lim, P - 2);
        for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P;
    }
}

// Non-recursive Newton iteration: each round doubles the precision of the inverse,
// alternating between the two rows f[0] and f[1].
int main() {
    read(n);
    for (int i = 0; i < n; i++) {
        read(a[i]);
        a[i] = (a[i] % P + P) % P;  // reduce into [0, P) so negative input is handled too
    }
    f[0][0] = fpow(a[0], P - 2);
    int dep = 1;  // log2(lim) - 1 for this round's transform, i.e. the bit-reversal top shift
    for (int len = 1; len < n; len <<= 1, ++dep) {
        // This round lifts the inverse from len to 2 * len coefficients.
        lim = len << 1;
        for (int i = 0; i < lim; i++) tmp[i] = a[i];
        // A (2 * len terms) times H^2 (2 * len - 1 terms) fits in a transform of 4 * len.
        lim <<= 1;
        for (int i = 0; i < lim; i++) pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << dep);
        for (int i = (len << 1); i < lim; i++) tmp[i] = 0;
        int now = dep & 1, pre = (dep - 1) & 1;
        // f[pre] holds H in [0, len) and zeros up to lim: the row was cleared past its
        // valid prefix after its last inverse transform and never written further out.
        ntt(f[pre], 1), ntt(tmp, 1);
        // Point-wise Newton step: G = 2H - A * H^2.
        for (int i = 0; i < lim; i++)
            f[now][i] = (2LL * f[pre][i] % P - tmp[i] * f[pre][i] % P * f[pre][i] % P + P) % P;
        ntt(f[now], -1);
        // Keep only the 2 * len valid coefficients.
        for (int i = (len << 1); i < lim; i++) f[now][i] = 0;
    }
    --dep;  // undo the final ++dep so that dep & 1 names the last row written
    for (int i = 0; i < n; i++) printf("%lld%c", f[dep & 1][i], i == (n - 1) ? '\n' : ' ');
    return 0;
}