【数学】【多项式】多项式求逆

写在前面

多项式求逆

前置知识:NTT

多项式求逆

给定一个多项式 \(F\left(x\right)\)(系数在模 \(998244353\) 意义下,且常数项 \(F\left(0\right)\not\equiv 0\)),求一个多项式 \(G\left(x\right)\),使得 \(F\left(x\right)G\left(x\right)\equiv 1\left(\bmod x^n\right)\),其中所有系数均对 \(998244353\) 取模。

考虑递归求解。

假定现在已经求出了 \(G_0\left(x\right)\),满足

\[F\left(x\right)G_0\left(x\right)\equiv 1\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\tag 1 \]

根据要求的 \(G\left(x\right)\) 的定义,显然有

\[F\left(x\right)G\left(x\right) \equiv 1\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\tag 2 \]

\((2) - (1)\),得

\[F\left(x\right)\left(G\left(x\right) - G_0\left(x\right)\right) \equiv 0 \left(\bmod x^{\lceil\frac{n}{2}\rceil}\right) \]

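由 \((2)\) 知,\(G\left(x\right)\) 是 \(F\left(x\right)\) 在模 \(x^{\lceil\frac{n}{2}\rceil}\) 意义下的逆元(注意仅凭 \(F\left(x\right)\not\equiv 0\) 不能直接约去 \(F\left(x\right)\),因为模 \(x^k\) 的多项式环中存在零因子;这里真正用到的是 \(F\left(0\right)\neq 0\))。两边同乘 \(G\left(x\right)\),所以有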

\[G\left(x\right) - G_0\left(x\right) \equiv 0 \left(\bmod x^{\lceil\frac{n}{2}\rceil}\right) \]

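两边平方。由于 \(G\left(x\right)-G_0\left(x\right)\) 的最低非零项次数至少为 \(\lceil\frac{n}{2}\rceil\),其平方的最低非零项次数至少为 \(2\lceil\frac{n}{2}\rceil\ge n\),所以模数可以提升到 \(x^n\),得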

\[G^2\left(x\right) - 2G\left(x\right)G_0\left(x\right) + G_0^2\left(x\right) \equiv 0\left(\bmod x^n\right) \]

两边同乘 \(F\left(x\right)\),并利用 \(F\left(x\right)G\left(x\right)\equiv 1\left(\bmod x^n\right)\),得

\[G\left(x\right) - 2G_0\left(x\right) + F\left(x\right)G_0^2\left(x\right) \equiv 0\left(\bmod x^n\right) \]

移项整理

\[G\left(x\right) \equiv 2G_0\left(x\right) - F\left(x\right)G_0^2\left(x\right) \left(\bmod x^n\right) \]

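递归的边界是 \(n=1\):此时 \(G\left(x\right)\) 只有常数项 \(F\left(0\right)^{-1}\),由费马小定理可用快速幂求出。递归求出 \(G_0\left(x\right)\) 之后,按上式自下而上递推即可。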

代码:

int rev[Maxn];	// rev[i]: bit-reversal of i, filled by Setrev
/* Build the bit-reversal permutation for an NTT of length len (a power of two).
 * Each rev[i] is derived from rev[i >> 1]: shift it right by one bit, then put
 * i's lowest bit into the top position (len >> 1).
 * rev[0] is never written here; it relies on the global being zero-initialised. */
void Setrev(int len) {
	const int topBit = len >> 1;
	for(int i = 1; i < len; ++i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? topBit : 0);
}

/* In-place iterative NTT of p[0, len) modulo Mod.
 * type == 0: forward transform with root g[0]; type == 1: inverse transform
 * with root g[1], followed by multiplication by len^{-1}.
 * Requires Setrev(len) to have been called for the same len. */
void ntt(LL p[], int len, int type) {
	// Reorder into bit-reversed positions so butterflies can run in place.
	for(int i = 0; i < len; ++i)
		if(i < rev[i]) swap(p[i], p[rev[i]]);
	// Merge blocks of size step/2 into blocks of size step.
	for(int step = 2; step <= len; step <<= 1) {
		const int half = step >> 1;
		const LL wn = qpow(g[type], (Mod - 1) / step);	// step-th root of unity
		for(int base = 0; base < len; base += step) {
			LL w = 1;
			for(int off = 0; off < half; ++off) {
				LL &lo = p[base + off], &hi = p[base + off + half];
				const LL u = lo % Mod;
				const LL v = w * hi % Mod;
				lo = (u + v) % Mod;
				hi = ((u - v) % Mod + Mod) % Mod;
				w = w * wn % Mod;
			}
		}
	}
	if(type != 1) return;
	// Inverse transform: scale every coefficient by len^{-1}.
	const LL invLen = qpow(len, Mod - 2);
	for(int i = 0; i < len; ++i) p[i] = p[i] * invLen % Mod;
}

LL tmp[Maxn];	// scratch buffer: NTT of A's first siz coefficients

/* Compute B(x) = A(x)^{-1} mod x^siz (coefficients mod Mod) by Newton iteration:
 *   B = 2*B0 - A*B0^2 (mod x^siz),  where B0 = A^{-1} mod x^{ceil(siz/2)}.
 * Assumes A[0] is not 0 mod Mod (otherwise no inverse exists and qpow yields 0).
 * Assumes B[1..] is zero on the top-level call (e.g. B is a zero-initialised global),
 * because coefficients above the current precision are fed into the NTT.
 * On return B[0, siz) holds the inverse and the NTT tail B[siz, len) is cleared. */
void polyinv(LL A[], LL B[], int siz) {
	if(siz <= 0) return;	// nothing to compute; siz == 0 would otherwise recurse forever
	if(siz == 1) {B[0] = qpow(A[0], Mod - 2); return;}	// Fermat inverse of the constant term
	polyinv(A, B, (siz + 1) >> 1);
	// Smallest power of two >= 2*siz: B0*B0*A has degree <= 2*siz - 2, so the cyclic
	// convolution never wraps around into the low siz coefficients that we keep.
	int len = 1; while(len < (siz << 1)) len <<= 1;
	for(int i = 0; i < siz; ++i) tmp[i] = A[i];
	for(int i = siz; i < len; ++i) tmp[i] = 0;
	Setrev(len); ntt(tmp, len, 0); ntt(B, len, 0);
	// Pointwise Newton step in the transformed domain: B = 2*B0 - B0*B0*A, kept in [0, Mod).
	for(int i = 0; i < len; ++i) B[i] = ((2ll * B[i] % Mod - B[i] * B[i] % Mod * tmp[i] % Mod) % Mod + Mod) % Mod;
	ntt(B, len, 1);
	for(int i = siz; i < len; ++i) B[i] = 0;	// drop terms of degree >= siz
}

实现上的一些小细节

  • 注意 NTT 的长度:求模 \(x^n\) 的逆时,\(G_0^2F\) 的次数不超过 \(2n-2\),所以 NTT 长度取不小于 \(2n\) 的 2 的幂即可避免循环卷积回绕到前 \(n\) 项;在满足这一点的前提下,长度稍长一些并不会影响结果,只会增加常数。

  • 最后那一步记得把 B 数组无用的元素清空。

  • 虽然看上去用了多次 NTT,但每层递归规模减半,有 \(T\left(n\right)=T\left(\frac{n}{2}\right)+\mathcal O\left(n\log n\right)\),根据主定理(如有需要请自行搜索),总复杂度仍旧是 \(\mathcal O\left(n \log n\right)\) 的。

完整代码

洛谷 P4238 多项式乘法逆

#include <bits/stdc++.h>

#define LL long long

using namespace std;

/* Fast reader: skips non-digit characters (remembering a '-' as the sign),
 * then parses a run of decimal digits from stdin into res.
 * If input ends before any digit is seen, res is left as 0 instead of spinning forever.
 * ch is an int so that EOF is distinguishable from every byte value, and so that
 * isdigit never receives a negative char (which is undefined behaviour). */
template <typename Temp> inline void read(Temp & res) {
	Temp fh = 1; res = 0;
	int ch = getchar();
	for(; ch != EOF && !isdigit(ch); ch = getchar()) if(ch == '-') fh = -1;
	for(; ch != EOF && isdigit(ch); ch = getchar()) res = res * 10 + (ch - '0');
	res = res * fh;
}

// Array capacity; must be at least the largest NTT length (a power of two) used.
// 262200 >= 2^18 = 262144. NOTE(review): presumably sized for n <= 1e5 (P4238) — confirm.
const int Maxn = 262200;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const LL Mod = 998244353;
// g[0] = 3 (primitive root of Mod) drives the forward NTT;
// g[1] = 332748118 = 3^{-1} mod Mod (3 * 332748118 = Mod + 1) drives the inverse NTT.
const LL g[2] = {3, 332748118};

/* Modular exponentiation by squaring: returns A^P mod Mod, in [0, Mod).
 * P must be non-negative (a negative P never reaches 0 under >>= and would loop).
 * A is reduced into [0, Mod) first so that A * A cannot overflow long long
 * even when the caller passes an unreduced or negative base. */
LL qpow(LL A, LL P) {
	A %= Mod;
	if(A < 0) A += Mod;
	LL res = 1;
	while(P) {
		if(P & 1) res = res * A % Mod;
		A = A * A % Mod;
		P >>= 1;
	}
	return res;
}

namespace Polynomial {
	int rev[Maxn];	// rev[i]: bit-reversal of i for the current NTT length

	/* Fill rev[1, len) with the bit-reversal permutation for length len (a power of two).
	 * rev[0] is not written; it relies on the global being zero-initialised. */
	void Setrev(int len) {
		for(int i = 1; i < len; ++i) {
			rev[i] = rev[i >> 1] >> 1;
			if(i & 1) rev[i] |= (len >> 1);
		}
	}

	/* In-place iterative NTT of p[0, len) modulo Mod.
	 * type == 0: forward transform with root g[0]; type == 1: inverse transform with
	 * root g[1], followed by multiplication by len^{-1}.
	 * Requires Setrev(len) to have been called for the same len. */
	void ntt(LL p[], int len, int type) {
		for(int i = 0; i < len; ++i) if(i < rev[i]) swap(p[i], p[rev[i]]); 
		for(int h = 2; h <= len; h <<= 1) {
			LL gn = qpow(g[type], (Mod - 1) / h);	// h-th root of unity
			for(int j = 0; j < len; j += h) {
				LL gk = 1;
				for(int k = j; k < j + h / 2; ++k) {
					LL e = p[k] % Mod, o = gk * p[k + h / 2] % Mod;
					p[k] = (e + o) % Mod; p[k + h / 2] = ((e - o) % Mod + Mod) % Mod;
					gk = gk * gn % Mod;
				}
			}
		}
		if(type == 1) {
			LL invl = qpow(len, Mod - 2);
			for(int i = 0; i < len; ++i) p[i] = p[i] * invl % Mod;
		}
	}
	
	/* A = A * B via a cyclic NTT of length len = the smallest power of two strictly
	 * greater than siz. NOTE: B is overwritten with its forward transform.
	 * Assumes the entries of A and B up to len that are not polynomial coefficients
	 * are zero, and that the true product fits in len coefficients — TODO confirm with callers. */
	void polymul(LL A[], LL B[], int siz) {
		int len = 1; while(siz) siz >>= 1, len <<= 1;
		Setrev(len); ntt(A, len, 0); ntt(B, len, 0);
		for(int i = 0; i < len; ++i) A[i] = A[i] * B[i] % Mod;
		ntt(A, len, 1);
	}
	
	LL tmp[Maxn];	// scratch buffer: NTT of A's first siz coefficients

	/* Compute B(x) = A(x)^{-1} mod x^siz (coefficients mod Mod) by Newton iteration:
	 *   B = 2*B0 - A*B0^2 (mod x^siz),  where B0 = A^{-1} mod x^{ceil(siz/2)}.
	 * Assumes A[0] is not 0 mod Mod (otherwise no inverse exists and qpow yields 0).
	 * Assumes B[1..] is zero on the top-level call (e.g. B is a zero-initialised global),
	 * because coefficients above the current precision are fed into the NTT.
	 * On return B[0, siz) holds the inverse and the NTT tail B[siz, len) is cleared. */
	void polyinv(LL A[], LL B[], int siz) {
		if(siz <= 0) return;	// nothing to compute; siz == 0 would otherwise recurse forever
		if(siz == 1) {B[0] = qpow(A[0], Mod - 2); return;}	// Fermat inverse of the constant term
		polyinv(A, B, (siz + 1) >> 1);
		// Smallest power of two >= 2*siz: B0*B0*A has degree <= 2*siz - 2, so the cyclic
		// convolution never wraps around into the low siz coefficients that we keep.
		int len = 1; while(len < (siz << 1)) len <<= 1;
		for(int i = 0; i < siz; ++i) tmp[i] = A[i];
		for(int i = siz; i < len; ++i) tmp[i] = 0;
		Setrev(len); ntt(tmp, len, 0); ntt(B, len, 0);
		// Pointwise Newton step in the transformed domain: B = 2*B0 - B0*B0*A, kept in [0, Mod).
		for(int i = 0; i < len; ++i) B[i] = ((2ll * B[i] % Mod - B[i] * B[i] % Mod * tmp[i] % Mod) % Mod + Mod) % Mod;
		ntt(B, len, 1);
		for(int i = siz; i < len; ++i) B[i] = 0;	// drop terms of degree >= siz
	}
}

int n;	// number of input coefficients (inverse is computed mod x^n)
// a: input polynomial; b: its inverse. Globals are zero-initialised, which
// polyinv relies on for the unused high coefficients of b.
LL a[Maxn], b[Maxn];

/* Read n and the n coefficients of F, then print F^{-1} mod x^n,
 * each coefficient followed by a space. */
int main() {
	read(n);
	int cnt = 0;
	while(cnt < n) read(a[cnt++]);
	Polynomial::polyinv(a, b, n);
	for(int k = 0; k < n; ++k)
		printf("%lld ", b[k]);
	return 0;
}
posted @ 2021-03-13 14:19  zimujun  阅读(263)  评论(1编辑  收藏  举报