【知识总结】多项式全家桶(一)(NTT、加减乘除和求逆)

我这种数学一窍不通的菜鸡终于开始学多项式全家桶了……

必须要会的前置技能:FFT(不会?戳我:【知识总结】快速傅里叶变换(FFT)

以下无特殊说明的情况下,多项式的长度指多项式最高次项的次数加\(1\)

一、NTT

跟FFT功能差不多,只是把复数域变成了模域(计算复数系数多项式相乘变成计算在模意义下整数系数多项式相乘)。你看FFT里的单位圆是循环的,模一个质数也是循环的嘛qwq。\(n\)次单位根\(w_n\)怎么搞?看这里:【BZOJ3328】PYXFIB(数学)(内含相关证明。只看与原根和单位根相关的内容即可。)

注意裸的NTT要求模数\(p\)存在原根并且\(p-1\)\(2\)的若干次幂的倍数(这个幂要大于多项式次数\(n\))。于是通常就会用著名的NTT模数:\(998244353=2^{23}\times 7\times 17+1\)

节约篇幅,代码先不放了。后面所有代码里都有NTT模板……

二、多项式求逆

对于\(n\)次多项式\(A\),如果有多项式\(B\)满足\(AB\equiv 1 \mod x^{n+1}\),则称\(B\)\(A\)在模\(x^{n+1}\)意义下的逆元(和整数逆元差不多)。通常采用倍增的方法求逆元。通常都会规定多项式系数在模\(p\)的意义下。

首先,\(A\)在模\(x\)的意义下就只有一个常数项,所以此时的逆元\(B\)也只有一个常数项,就是\(A\)的常数项模\(p\)的逆元。

如果我们知道\(B_0\)\(A\)在模\(x^{\lceil\frac{n}{2}\rceil}\)意义下的逆元,现在要求\(B\)\(A\)在模\(x^n\)意义下的逆元。根据题设,显然有:

\[AB=1\mod x^n \]

很明显,\(AB\)\(1\)\(n-1\)次项系数全是\(0\),所以模一个\(x\)的低于\(n\)次幂也一定是\(1\)。所以

\[AB_0=AB=1\mod x^{\lceil\frac{n}{2}\rceil} \]

那么

\[B-B_0=0\mod x^{\lceil\frac{n}{2}\rceil} \]

两边和模数同时平方:

\[B^2+B_0^2-2BB_0=0\mod x^n \]

两边同时乘\(A\),得到(别忘了\(AB=1\mod x^n\)):

\[B+AB_0^2-2B_0=0\mod x^n \]

然后移项,得到:

\[B=2B_0-AB_0^2\mod x^n \]

照着这个式子递归算就行了。

由于后面带余除法的代码包含求逆,所以代码同样略去……

三、加减乘除

加减法:直接每项对应相加减。

乘法:这就是NTT的目的啊喂!

除法:如果不是带余除法直接乘逆元。下面着重介绍带余除法。

已知\(n-1\)次多项式\(F\)\(m-1\)次多项式\(G\),求\(n-m\)次多项式\(Q\)和多项式\(R\)\(R\)的次数小于\(m-1\)),满足:

\[F(x)=Q(x)G(x)+R(x) \mod x^n \]

很明显,主要的难点在于式子里有个叫做\(R\)的嘴子(兔崽子Tzz)。如果能把它搞掉该多好……

注意到\(R\)的次数小于\(m-1\),那么我们把它翻转,末尾补\(0\),是不是就可以把它模成\(0\)了?定义\(\mathrm{Tzz}_{A,n}\)表示把\(A\)视作一个长为\(n\)的多项式(高次项补\(0\))后翻转的结果。即\(\mathrm{Tzz}_{A,n}(x)=x^{n-1}A(\frac{1}{x})=\sum\limits_{i=0}^{n-1}a_ix^{n-i-1}\)

\(F=QG+R\)的每个多项式都代入同一个数,这个多项式也一定是成立的。所以:

\[F(\frac{1}{x})=Q(\frac{1}{x})G(\frac{1}{x})+R(\frac{1}{x}) \]

两边同乘\(x^{n-1}\),得到:

\[x^{n-1}F(\frac{1}{x})=x^{n-m}Q(\frac{1}{x})\cdot x^{m-1}G(\frac{1}{x})+x^{n-1}R(\frac{1}{x}) \]

\[\mathrm{Tzz}_{F,n}=\mathrm{Tzz}_{Q,n-m+1}\mathrm{Tzz}_{G,m}+\mathrm{Tzz}_{R,n} \]

现在\(\mathrm{Tzz}_{R,n}\)的最高次项是\(n-1\),但是从常数项到\(n-m\)次项全是\(0\)(因为\(R\)的长度最多就是\(m-1\))。所以现在如果模\(n-m+1\),那么\(\mathrm{Tzz}_{R,n}\)就是\(0\)了,而\(\mathrm{Tzz}_{Q,n-m+1}\)因为最高次是\(n-m\)所以不会受到影响。

于是用\(\mathrm{Tzz}_{F,n}\)乘上\(\mathrm{Tzz}_{G,m}\)的逆元就是\(\mathrm{Tzz}_{Q,n-m+1}\),翻回去就能得到\(Q\)

最后把\(Q\)代进原式,乘一乘减一减就能算出\(R\)

所以这样为什么是对的?(以下“低次项”指翻转后的\(n-m\)项,“高次项”指翻转后的\(m\)项)首先在模\(x^{n-m+1}\)意义下肯定能保证低次项是对的(即\(\mathrm{Tzz}{F,n}\)\(\mathrm{Tzz}_{G,m}\mathrm{Tzz}_{Q,n-m+1}\)的前\(n-m\)项相等)。至于高次项,反正有\(\mathrm{Tzz}_{R,n}\)来补锅,所以即使不对也没关系。

完结撒花。

下一篇:【知识总结】多项式全家桶(二)(ln和exp)

代码:洛谷4512

注意NTT的数组一定要保证多余的元素全部是\(0\)

代码开头的#undef是防机惨护身符。

(我脑子有病啊求原根全是手写的……

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
#undef i
#undef j
#undef k
#undef min
#undef max
#undef swap
#undef sort
#undef for
#undef while
#undef if
#undef true
#undef false
#undef printf
#undef scanf
#undef getchar
#undef putchar
#define _ 0
using namespace std;

namespace zyt
{
	template<typename T>
	inline bool read(T &x)
	{
		char c;
		bool f = false;
		x = 0;
		do
			c = getchar();
		while (c != EOF && c != '-' && !isdigit(c));
		if (c == EOF)
			return false;
		if (c == '-')
			f = true, c = getchar();
		do
			x = x * 10 + c - '0', c = getchar();
		while (isdigit(c));
		if (f)
			x = -x;
		return true;
	}
	template<typename T>
	inline void write(T x)
	{
		static char buf[20];
		char *pos = buf;
		if (x < 0)
			putchar('-'), x = -x;
		do
			*pos++ = x % 10 + '0';
		while (x /= 10);
		while (pos > buf)
			putchar(*--pos);
	}
	typedef long long ll;
	const int N = 1e5 + 10, LEN = (N << 2), p = 998244353;
	namespace Polynomial
	{
		inline int power(int a, int b)
		{
			a %= p, b %= p - 1;
			int ans = 1;
			while (b)
			{
				if (b & 1)
					ans = (ll)ans * a % p;
				a = (ll)a * a % p;
				b >>= 1;
			}
			return ans;
		}
		inline int inv(const int a)
		{
			return power(a, p - 2);
		}
		namespace Primitive_Root
		{
			pair<int, int> prime[20];
			int cnt;
			void get_prime(int n)
			{
				cnt = 0;
				for (int i = 2; i * i <= n; i++)
				{
					if (n % i == 0)
						prime[cnt++] = make_pair(i, 0);
					while (n % i == 0)
						++prime[cnt - 1].second, n /= i;
				}
			}
			int get_g(const int n)
			{
				get_prime(n - 1);
				for (int i = 2; i < n; i++)
				{
					bool flag = true;
					for (int j = 0; j < cnt && flag; j++)
						flag &= (power(i, (n - 1) / prime[j].first) != 1);
					if (flag)
						return i;
				}
				return -1;
			}
		}
		int omega[LEN], winv[LEN], rev[LEN];
		void init(const int n, const int lg2)
		{
			static int g = 0;
			if (!g)
				g = Primitive_Root::get_g(p);
			int w = power(g, (p - 1) / n), wi = inv(w);
			omega[0] = winv[0] = 1;
			for (int i = 1; i < n; i++)
			{
				omega[i] = (ll)omega[i - 1] * w % p;
				winv[i] = (ll)winv[i - 1] * wi % p;
			}
			for (int i = 0; i < n; i++)
				rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
		}
		void ntt(int *a, const int *w, const int n)
		{
			for (int i = 0; i < n; i++)
				if (i < rev[i])
					swap(a[i], a[rev[i]]);
			for (int l = 1; l < n; l <<= 1)
				for (int i = 0; i < n; i += (l << 1))
					for (int k = 0; k < l; k++)
					{
						int tmp = (a[i + k] - (ll)w[n / (l << 1) * k] * a[i + l + k] % p + p) % p;
						a[i + k] = (a[i + k] + (ll)w[n / (l << 1) * k] * a[i + l + k] % p) % p;
						a[i + l + k] = tmp;
					}
		}
		void reverse(int *a, const int n)
		{
			static int tmp[LEN];
			memcpy(tmp, a, sizeof(int[n]));
			for (int i = 0; i < n; i++)
				a[i] = tmp[n - i - 1];
		}
		inline void plus(const int *a, const int *b, int *c, const int n)
		{
			for (int i = 0; i < n; i++)
				c[i] = (a[i] + b[i]) % p;
		}
		inline void minus(const int *a, const int *b, int *c, const int n)
		{
			for (int i = 0; i < n; i++)
				c[i] = (a[i] - b[i] + p) % p;
		}
		void _inv(const int *a, int *b, const int n)
		{
			if (n == 1)
				b[0] = inv(a[0]);
			else
			{
				static int tmp[LEN];
				_inv(a, b, (n + 1) >> 1);
				int m = 1, lg2 = 0;
				while (m < (n << 1) - 1)
					m <<= 1, ++lg2;
				memcpy(tmp, a, sizeof(int[n]));
				memset(tmp + n, 0, sizeof(int[m - n]));
				memset(b + ((n + 1) >> 1), 0, sizeof(int[m - ((n + 1) >> 1)]));
				init(m, lg2);
				ntt(tmp, omega, m);
				ntt(b, omega, m);
				for (int i = 0; i < m; i++)
					b[i] = (b[i] * 2LL % p - (ll)tmp[i] * b[i] % p * b[i] % p + p) % p;
				ntt(b, winv, m);
				int invm = inv(m);
				for (int i = 0; i < m; i++)
					b[i] = (ll)b[i] * invm % p;
				memset(b + n, 0, sizeof(int[m - n]));
			}
		}
		void inv(const int *a, int *b, const int n)
		{
			static int tmp[LEN];
			memcpy(tmp, a, sizeof(int[n]));
			_inv(tmp, b, n);
		}
		void mul(const int *a, const int *b, int *c, const int n)
		{
			int m = 1, lg2 = 0;
			while (m < (n << 1))
				m <<= 1, ++lg2;
			static int x[LEN], y[LEN];
			memcpy(x, a, sizeof(int[n]));
			memset(x + n, 0, sizeof(int[m - n]));
			memcpy(y, b, sizeof(int[n]));
			memset(y + n, 0, sizeof(int[m - n]));
			init(m, lg2);
			ntt(x, omega, m);
			ntt(y, omega, m);
			for (int i = 0; i < m; i++)
				x[i] = (ll)x[i] * y[i] % p;
			ntt(x, winv, m);
			int invm = inv(m);
			for (int i = 0; i < m; i++)
				x[i] = (ll)x[i] * invm % p;
			memcpy(c, x, sizeof(int[n]));
		}
		void div(const int *_F, const int *_G, int *_Q, int *_R, const int n, const int m)
		{
			static int F[LEN], G[LEN], invG[LEN], Q[LEN], R[LEN];
			memcpy(F, _F, sizeof(int[n]));
			memcpy(G, _G, sizeof(int[m]));
			reverse(F, n), reverse(G, m);
			if (m < n - m + 1)
				memset(G + m, 0, sizeof(int[n - m + 1 - m]));
			inv(G, invG, n - m + 1);
			mul(F, invG, Q, n - m + 1);
			reverse(F, n), reverse(G, m), reverse(Q, n - m + 1);
			mul(G, Q, G, n);
			minus(F, G, R, n);
			memcpy(_Q, Q, sizeof(int[n - m + 1]));
			memcpy(_R, R, sizeof(int[m]));
		}
	}
	int F[LEN], G[LEN], Q[LEN], R[LEN];
	int work()
	{
		int n, m;
		read(n), read(m);
		++n, ++m;
		for (int i = 0; i < n; i++)
			read(F[i]);
		for (int i = 0; i < m; i++)
			read(G[i]);
		Polynomial::div(F, G, Q, R, n, m);
		for (int i = 0; i < n - m + 1; i++)
			write(Q[i]), putchar(' ');
		putchar('\n');
		for (int i = 0; i < m - 1; i++)
			write(R[i]), putchar(' ');
		return (0^_^0);
	}
}
int main()
{
	return zyt::work();
}
posted @ 2019-01-06 00:38  Inspector_Javert  阅读(874)  评论(0编辑  收藏  举报