【知识总结】多项式全家桶(三)(任意模数NTT)

经过两个月的咕咕,“多项式全家桶” 系列终于迎来了第三期……(雾)

上一篇:【知识总结】多项式全家桶(二)(ln和exp)

先膜拜(伏地膜)大恐龙的博客:任意模数 NTT

在页面右侧面板 “您想嘴谁” 中选择 “大恐龙” 就可以在页面左下角戳她哦

首先务必先学会 NTT (如果不会,请看多项式全家桶(一)),并充分理解中国剩余定理……

之前提到了,普通 NTT 的模数必须是一个质数,且这个质数中必须有一个足够大的 \(2\) 的幂作为因子。然而,毕竟满足这样条件的数不多,如果题目给定的模数不满足这样的条件呢?当然是大骂出题人毒瘤简单来说,就是用三个满足条件的模数 (常用的有 \(469762049\)\(998244353\)\(1004535809\)。最后将介绍如何背过它们)分别做 NTT ,然后用中国剩余定理(下文缩写作 CRT )合并。因为如果所有系数都不超过 \(10^9\) ,项数是 \(10^5\) ,那么最终每项系数(不取模)最大是 \(10^9\times 10^9 \times 10^5=10^{23}\)。而 CRT 合并后的数是模上述三个质数乘积(大约 \(10^{26}\) 这个数量级)的,所以把 CRT 合并后的答案模题目给定的模数即可。

然而很明显 long long 存不下啊……于是 CRT 的时候有技巧。

很明显 CRT 的时候各项独立,于是每一项分别处理。设这一项答案是 \(x\) ,上述三个质数是 \(p_0\)\(p_1\)\(p_2\) ,满足:

\[\begin{cases} x=a_0\mod p_0 \\ x=a_1 \mod p_1\\ x=a_2 \mod p_2\\ \end{cases}\]

暴力合并是会爆 long long 的,不过可以先合并前两项( \(p_0\times p_1<2^{63}\)),得到(\(\mathrm{inv}(x,y)\) 表示 \(x\) 在模 \(y\) 意义下的逆元):

\[x=a_0p_1\mathrm{inv}(p_1, p_0)+a_1p_0\mathrm{inv}(p_0, p_1)\mod p_0p_1 \]

把右边那一堆设为 \(t_0\) (模仿恐龙的变量名 qwq ),设答案为 \(t_1p_0p_1+t_0\) ,则有:

\[t_1p_0p_1+t_0=a_2\mod p_2 \]

即:

\[t_1=(a_2-t_0)\mathrm{inv}(p_0p_1, p_2)\mod p_2 \]

因为 \(x<p_0p_1p_2\) ,所以 \(t_1<p_2\) ,所以求出来直接就是真正的 \(t_1\) 。然后在模题目给定的模数意义下算 \(t_1p_0p_1+t_0\) 就行了。

下面介绍如何记住那三个模数,记忆力好的可以直接跳过

我是在机房里以被兔崽子揍一顿的代价把这三个数字吼了一上午并且用安徽黄梅戏《天仙配》的曲调唱了一晚上才背过

先看一眼要记住的东西:

469762049、998244353、1004535809

记不住?跟我大声念:

肆陆玖柒陆贰零肆玖!

玖玖捌贰肆肆叁伍叁!

壹零零肆伍叁伍捌零玖!

没记住?再来一遍:

肆陆玖柒陆贰零肆玖!

玖玖捌贰肆肆叁伍叁!

壹零零肆伍叁伍捌零玖!

还没记住?再吼十遍

在页面右侧面板点一下 “运动员进行曲” ,然后跟着唱:

(前奏:当当当当当当当当当当当当当!)

肆陆玖柒啊贰零~~肆玖~~

陆陆玖柒陆啊肆玖~

壹零零零零肆伍叁伍~捌!零!玖!

玖玖捌贰肆肆肆肆肆叁叁叁~叁伍叁!

(以上重复一遍)

肆陆柒陆零肆玖~~

玖玖肆肆叁伍叁~~

壹零零~啊肆伍叁伍捌~零玖

玖玖捌贰肆肆叁伍叁~~

(以上重复一遍)

代码:

注意在模 \(p_0p_1\) 意义下算乘法的时候要用龟速乘防止爆 long long 。这里介绍一种从恐龙那里学来的\(O(1)\) 龟速乘:

inline ll mul(ll a, ll b, const ll p)
{
	const static long double eps = 1e-15;
	a %= p, b %= p;
	return ((a * b - ll((long double)a / p * b + eps) * p) % p + p) % p;
}

原理是利用了 long long 溢出后不是随机数,而只是真实值减去了 \(2^{64}\) ,所以 \(a\)\(\lfloor\frac{a}{b}\rfloor\cdot b\) 之差仍然是 \(a\mod b\)

下一篇:【知识总结】多项式全家桶(四)(快速幂和开根)

完整代码(洛谷4245):

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
using namespace std;

namespace zyt
{
	template<typename T>
	inline bool read(T &x)
	{
		char c;
		bool f = false;
		x = 0;
		do
			c = getchar();
		while (c != EOF && c != '-' && !isdigit(c));
		if (c == EOF)
			return false;
		if (c == '-')
			f = true, c = getchar();		
		do
			x = x * 10 + c - '0', c = getchar();
		while (isdigit(c));
		if (f)
			x = -x;
		return true;
	}
	template<typename T>
	inline void write(T x)
	{
		static char buf[20];
		char *pos = buf;
		if (x < 0)
			putchar('-'), x = -x;
		do
			*pos++ = x % 10 + '0';
		while (x /= 10);
		while (pos > buf)
			putchar(*--pos);
	}
	typedef long long ll;
	const int N = 1e5 + 10, LEN = N << 2, p[] = {469762049, 998244353, 1004535809};
	int n;
	inline ll mul(ll a, ll b, const ll p)
	{
		const static long double eps = 1e-15;
		a %= p, b %= p;
		return ((a * b - ll((long double)a / p * b + eps) * p) % p + p) % p;
	}
	inline ll power(ll a, ll b, const ll p)
	{
		ll ans = 1;
		while (b)
		{
			if (b & 1)
				ans = (ll)ans * a % p;
			a = (ll)a * a % p;
			b >>= 1;
		}
		return ans;
	}
	ll inv(const ll a, const ll p)
	{
		return power(a % p, p - 2, p);
	}
	namespace Polynomial
	{
		int rev[LEN], omega[LEN], winv[LEN], p;
		void ntt(int *a, const int *w, const int n)
		{
			for (int i = 0; i < n; i++)
				if (i < rev[i])
					swap(a[i], a[rev[i]]);
			for (int l = 1; l < n; l <<= 1)
				for (int i = 0; i < n; i += (l << 1))
					for (int k = 0; k < l; k++)
					{
						int x = a[i + k], y = (ll)w[n / (l << 1) * k] * a[i + l + k] % p;
						a[i + k] = (x + y) % p;
						a[i + l + k] = (x - y + p) % p;
					}
		}
		void init(const int n, const int lg2)
		{
			const static int g = 3;
			int w = power(g, (p - 1) / n, p), wi = inv(w, p);
			omega[0] = winv[0] = 1;
			for (int i = 1; i < n; i++)
			{
				omega[i] = (ll)omega[i - 1] * w % p;
				winv[i] = (ll)winv[i - 1] * wi % p;
			}
			for (int i = 0; i < n; i++)
				rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1));
		}
		void mul(const int *a, const int *b, int *c, const int n, const int m, const int _p)
		{
			static int x[LEN], y[LEN];
			p = _p;
			int len = 1, lg2 = 0;
			while (len < (n + m - 1))
				len <<= 1, ++lg2;
			memcpy(x, a, sizeof(int[n]));
			memset(x + n, 0, sizeof(int[len - n]));
			memcpy(y, b, sizeof(int[m]));
			memset(y + m, 0, sizeof(int[len - m]));
			init(len, lg2);
			ntt(x, omega, len), ntt(y, omega, len);
			for (int i = 0; i < len; i++)
				x[i] = (ll)x[i] * y[i] % p;
			ntt(x, winv, len);
			int invlen = inv(len, p);
			for (int i = 0; i < len; i++)
				c[i] = (ll)x[i] * invlen % p;
		}
	}
	int A[N], B[N], ans[4][LEN];
	void CRT(const int n, const int mod)
	{
		ll p01 = (ll)p[0] * p[1];
		ll m0 = mul(p[1], inv(p[1], p[0]), p01), m1 = mul(p[0], inv(p[0], p[1]), p01);
		ll inv2 = inv(p01, p[2]);
		for (int i = 0; i < n; i++)
		{
			ll t0 = (mul(ans[0][i], m0, p01) + mul(ans[1][i], m1, p01)) % p01;
			ll t1 = (ans[2][i] - t0 % p[2] + p[2]) % p[2] * inv2 % p[2];
			ans[3][i] = (t1 % mod * (p01 % mod) % mod + t0) % mod;
		}
	}
	int work()
	{
		int n, m, P;
		read(n), read(m), read(P);
		++n, ++m;
		for (int i = 0; i < n; i++)
			read(A[i]);
		for (int i = 0; i < m; i++)
			read(B[i]);
		for (int i = 0; i < 3; i++)
			Polynomial::mul(A, B, ans[i], n, m, p[i]);
		CRT(n + m - 1, P);
		for (int i = 0; i < n + m - 1; i++)
			write(ans[3][i]), putchar(' ');
		return 0;
	}
}
int main()
{
	freopen("4245.in", "r", stdin);
	return zyt::work();
}
posted @ 2019-03-13 15:31  Inspector_Javert  阅读(753)  评论(0编辑  收藏  举报