多项式除法及取模

多项式除法及取模

http://blog.miskcoo.com/2015/05/polynomial-division

概述

给出一个 \(n\) 次多项式 \(A(x)\) 和一个 \(m(m \le n)\) 次多项式 \(B(x)\) ,要求求出两个多项式 \(D(x), R(x)\) , 满足

\[A(x) = D(x)B(x) + R(x) \]

其中 \(degD \le degA - degB = n - m, degR < m\)

可以在 \(O(n \log n)\) 的时间求解

原理

首先,我们先想办法消除 \(R(x)\) 的影响,我们定义

\[A^R(x) = x^nA(\dfrac 1x) \]

实际上就是将 \(A(x)\) 的系数翻转, 例如

\[A(x) = x^3 + 2x^2 + 4x + 1 \\ A^R(x) = x^3(x^{-3} + 2x^{-2} + 4x^{-1} + 1) = 1 + 2x + 4x^2 + x^3 \]

接下来,我们将 \((1)\) 中的 \(x\)\(\dfrac 1x\) 替换,并在两边同乘 \(x^n\)

\[x^nA(\dfrac 1x) = x^{n-m}D(\dfrac 1x) x^m B(\dfrac 1x) + x^{n - m + 1}x^{m - 1} R(\dfrac 1x) \\ A^R(x) = D^R(x) B^R(x) + x^{n - m +1} R^R(x) \]

观察发现此时 \(x^{n - m + 1}R^R(x)\) 的非零项都在 \(n - m + 1\) 上,而 \(D^R(x)\) 的最高次项为 \(n - m\) , 那么我们有

\[A^R(x) \equiv D^R(x)B^R(x) (mod \; x^{n - m + 1}) \]

那么我们就可以用一次求出逆元的复杂度求出 \(D(x)\) ,再代回原式就可以得到 \(R(x)\)

Code

  1. \(B\) 翻转,求出其在 \(mod \; x^{n - m - 1}\) 意义下的逆元 \(B'^R\)
  2. \(A\) 翻转, 得到 \(D^R(x) = A^R(x) \cdot B'^R(x)\)
  3. 求出 \(R(x) = A(x) - D(x)B(x)\)

不打换行真的会死.......

洛谷 P4512

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
inline char nc() {
	static char buf[100000], *l = buf, *r = buf;
	return l==r&&(r=(l=buf)+fread(buf,1,100000,stdin),l==r)?EOF:*l++;
}
template<class T> void read(T &x) {
	x = 0; int f = 1, ch = nc();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=nc();}
	while(ch>='0'&&ch<='9'){x=x*10-'0'+ch;ch=nc();}
	x *= f;
}
#define inv(a) quick_power(a, mod - 2)
typedef long long ll;
const int g = 3;
const int mod = 998244353;
const int phi = mod - 1;
const int maxn = 400000 + 5;
const int maxlog = 20;
int n, m; ll F[maxn], G[maxn], D[maxn], R[maxn];
inline ll sum(ll x) {
	return x >= mod ? x - mod : x;
}
inline ll dec(ll x) {
	return x < 0 ? x + mod : x;
}
ll quick_power(ll x, ll y) {
	ll re = 1;
	while(y) {
		if(y & 1) re = re * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return re;
}
namespace polynomial {
	int rev[maxn];
	ll w[maxlog][maxn][2];
	void init() {
		for(int i = 1, s = 0; i < maxn; i <<= 1, ++s) {
			ll wn0 = quick_power(g, phi / (i * 2));
			ll wn1 = quick_power(g, -phi / (i * 2) + phi);
			w[s][0][0] = w[s][0][1] = 1;
			for(int k = 1; k < i; ++k) {
				w[s][k][0] = w[s][k - 1][0] * wn0 % mod;
				w[s][k][1] = w[s][k - 1][1] * wn1 % mod;
			}
		}
	}
	void init_rev(int n, int L) {
		for(int i = 1; i < n; ++i) {
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
		}
	}
	void FFT(ll *A, int n, int f) {
		int d = f == -1;
		for(int i = 0; i < n; ++i) if(i > rev[i]) {
			swap(A[i], A[rev[i]]);
		}
		for(int i = 1, s = 0; i < n; i <<= 1, ++s) {
			for(int j = 0, p = i << 1; j < n; j += p) {
				ll *u = A + j, *v = A + j + i;
				for(int k = 0; k < i; ++k, ++u, ++v) {
					ll x = *u, y = *v * w[s][k][d] % mod;
					*u = sum(x + y);
					*v = dec(x - y);
				}
			}
		}
		if(f == -1) {
			ll r = inv(n);
			for(int i = 0; i < n; ++i) {
				A[i] = A[i] * r % mod;
			}
		}
	}
	void inverse(int step, ll *A, ll *B) {
		static ll T[maxn];
		
		if(step == 1) {
			B[0] = inv(A[0]);
			return;
		}
		
		inverse((step + 1) >> 1, A, B);
		
		int n = 1, L = 0, m = step << 1;
		for(n = 1; n <= m; n <<= 1) ++L;
		init_rev(n, L);
		
		copy(A, A + step, T), fill(T + step, T + n, 0);

		FFT(T, n, 1);
		FFT(B, n, 1);
		for(int i = 0; i < n; ++i) {
			B[i] = dec(2 - T[i] * B[i] % mod) * B[i] % mod;
		}
		FFT(B, n, -1);
		
		fill(B + step, B + n, 0);
	}
	void division(ll *A, ll *B, int n, int m, ll *D, ll *R) {
		static ll A0[maxn], B0[maxn];
		memset(A0, 0, sizeof(A0));
		memset(B0, 0, sizeof(B0));

		int len = n - m + 1;

		reverse_copy(B, B + m + 1, A0);
		inverse(len, A0, B0);

		int p, L = 0, tmp = len << 1;
		for(p = 1; p <= tmp; p <<= 1) ++L;
		init_rev(p, L);
		
		reverse_copy(A, A + n + 1, A0), fill(A0 + len, A0 + p, 0);
		
		FFT(A0, p, 1);
		FFT(B0, p, 1);
		for(int i = 0; i < p; ++i) {
			A0[i] = A0[i] * B0[i] % mod;
		}
		FFT(A0, p, -1);
		
		reverse(A0, A0 + len);
		copy(A0, A0 + len, D);
		
		L = 0, tmp = n;
		for(p = 1; p <= tmp; p <<= 1) ++L;
		init_rev(p, L);
		
		copy(B, B + m + 1, B0), fill(B0 + m + 1, B0 + p, 0);
		fill(A0 + len, A0 + p, 0);
		
		FFT(A0, p, 1);
		FFT(B0, p, 1);
		for(int i = 0; i < p; ++i) {
			A0[i] = A0[i] * B0[i] % mod;
		}
		FFT(A0, p, -1);
		
		for(int i = 0; i < m; ++i) {
			R[i] = dec(A[i] - A0[i]);
		}
	}
}
int main() {
//	freopen("testdata.in", "r", stdin);
	read(n), read(m);
	for(int i = 0; i <= n; ++i) read(F[i]);
	for(int i = 0; i <= m; ++i) read(G[i]);

	polynomial :: init();
	polynomial :: division(F, G, n, m, D, R);
	
	for(int i = 0; i <= n - m; ++i) {
		if(i) printf(" ");
		printf("%lld", D[i]);
	}
	printf("\n");
	for(int i = 0; i < m; ++i) {
		if(i) printf(" ");
		printf("%lld", R[i]);
	}
	printf("\n");
	
	return 0;
}
posted @ 2020-05-18 11:25  LJZ_C  阅读(298)  评论(0编辑  收藏  举报