多项式乘法

多项式乘法

和卷积

\(h=f\times g\),则 \(h_n=\sum_{i=0}^n f_ig_{n-i}\)

朴素的乘法是 \(O(n^2)\) 的,看起来到了极限

但怎样再快一些?

发现如果把多项式换成点值表示,则 \(O(n)\) 即可计算!

问题是,用系数表示的次数为 \(n\) 的多项式化成点值就要 \(O(n^2)\)

于是引入了 系数 \(\to\) 点值 \(\to\) 系数的各种多项式乘法算法

FFT

前置知识:复数

DFT:通过系数算出点值

想快速计算点值,考虑取各种特殊点

但实数域内的特殊数有限,考虑扩展到复数域

对于次数为 \(n-1\) 的多项式,代入 \(n\) 个点值

\(\omega_n^{0\sim n-1}\),即单位复根

复平面上,它的终点把单位圆等分为 \(n\) 等份

它的性质:(可以用单位圆的图示理解)

  • \(\omega_n^0=\omega_n^n=1\) 在实轴上

  • \(\omega_n^k=-\omega_n^{k+\frac n 2}\) 关于原点对称

  • \(\omega_n^k=\omega_n^{n+k}\) 转了一圈

  • \(\omega_n^k=\omega_{2n}^{2k}\) 辐角一样

然后,把多项式奇偶次项分开

\(f(x)=a_0x^0+a_2x^2+a_4x^4\dots+a_1x_1+a_3x^3+a_5x^5\)

\(g(x)=a_0x^0+a_2x^1+a_4x^2\dots\)\(h(x)=a_1x^0+a_3x^1+a_5x^2\)

\(f(x)=g(x^2)+x\times h(x^2)\)

发现多项式的次数仅为原来的 \(\frac 1 2\)

代入 \(n\) 个单位复根

\(f(\omega_n^i)=g((\omega_n^i)^2)+\omega_n^i\times h((\omega_n^i)^2)\)

\(f(\omega_n^{i+\frac n 2})=g((\omega_n^{i+\frac n 2})^2)+\omega_n^{i+\frac n 2}\times h((\omega_n^{i+\frac n 2})^2)=g((\omega_n^i)^2)-\omega_n^i\times h((\omega_n^i)^2)\)

发现这两个式子只有一个常数项不同,于是只计算前面一半可以直接得到后面

那么,分成两半递归分治计算 \(f\),复杂度是 \(O(n\log n)\)

IDFT:点值重新变为系数

证明看不会了……其实是懒得看,况且不知道好像没关系

感性理解一下,“相反”,把单位复根全部变为它的共轭复数即可

结论有:第 \(i\) 项最后的实数部分为 \(c_i\),则系数 \(a_i=\frac{c_i}n\)

注意实现时拿 STL 中的 complex<double>,用 real() 取出实数部分

因为每次分治,所以次数应补全为 \(lim=2^k\)

递归版代码:

void fft(comp *f, ll n, ll typ)
{
	if(n == 1)	return;
	for(ll i = 0; i < n; ++i)	tmp[i] = f[i];
	for(ll i = 0; i < n; ++i) // 奇偶分开
		if(i & 1)	f[(i >> 1) + (n >> 1)] = tmp[i];
		else	f[i >> 1] = tmp[i];
	comp *g = f, *h = f + n / 2;
	fft(g, n >> 1, typ), fft(h, n >> 1, typ); // 递归计算
	comp cur(1, 0), w(cos((double)2.0 * pi / n), sin((double)2.0 * pi / n) * typ);
	for(ll i = 0; i < (n >> 1); ++i, cur *= w) // 算出当前的值
		tmp[i] = g[i] + h[i] * cur, tmp[i + (n >> 1)] = g[i] - h[i] * cur;
	for(ll i = 0; i < n; ++i)	f[i] = tmp[i]; // 更新
}
int main()
{
	n = read(), m = read();
	for(ll i = 0; i <= n; ++i)	a[i] = {(double)read(), 0.0};
	for(ll i = 0; i <= m; ++i)	b[i] = {(double)read(), 0.0};
	for(; sizn <= n + m; sizn <<= 1);
	fft(a, sizn, 1), fft(b, sizn, 1);
	for(ll i = 0; i <= sizn; ++i)	a[i] = a[i] * b[i];
	fft(a, sizn, -1);
	for(ll i = 0; i <= n + m; ++i)	print((ll)(a[i].real() / sizn + 0.5)), putchar(' ');
	return 0;
}

优化

递归版常数比较大,加上复数运算自带大常数

于是想变成非递归版

一组一组的分治下去,要知道最后一组两两配对的情况

这里把每个数与它的二进制表示反转后表示的数交换(高位补的 0 到前面)

可以递推求出

然后枚举每层的次数,按上述步骤计算即可

void fft(comp *f, ll n, ll typ) // 非递归版,自底向上模拟递归过程 
{
	for(ll i = 0; i < n; ++i)	rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
	for(ll i = 0; i < n; ++i) // 交换顺序 
		if(i < rev[i])	swap(f[i], f[rev[i]]);
	for(ll step = 2; step <= n; step <<= 1) // 当前计算的组内的元素个数 
	{
		ll gap = step >> 1;	
		comp w(cos(pi * (double)2.0 / step), typ * sin(pi * (double)2.0 / step));
		for(ll i = 0; i < n; i += step) // 跳到下一组 
		{
			comp cur(1.0, 0.0);
			for(ll j = i; j < i + gap; ++j, cur *= w)	// 与下层合并 
			{
				comp p = f[j], q = cur * f[j + gap];
				f[j] = p + q, f[j + gap] = p - q;
			}
		}
	}
}
int main()
{
	n = read(), m = read();
	for(ll i = 0; i <= n; ++i)	a[i] = {(double)read(), 0.0};
	for(ll i = 0; i <= m; ++i)	b[i] = {(double)read(), 0.0};
	for(; sizn <= n + m; sizn <<= 1);
	fft(a, sizn, 1), fft(b, sizn, 1);
	for(ll i = 0; i <= sizn; ++i)	a[i] = a[i] * b[i];
	fft(a, sizn, -1);
	for(ll i = 0; i <= n + m; ++i)	print((ll)(a[i].real() / sizn + 0.5)), putchar(' ');
	return 0;
}

NTT

如果多项式系数较大,题目中需要取模,就不好用复数计算

那有什么在模意义下能代替复数呢?

模数为质数,那它的原根 \(g^{p-1}\equiv1(\bmod\ p)\),且 \(g^{0\sim p-2}\) 互不相同

它甚至有与单位复根相似的性质:

  • \(g^n\equiv g^{n+\frac {p-1}2}(\bmod\ p)\)

  • \(g^n\equiv g^{n+p-1}(\bmod\ p)\)

那么,直接把 FFT 中的单位复根换成原根即可

但对模数要求较为苛刻,因为把 \(p-1\) \(2^k\) 等分并不容易

\(998244353=7\times17\times2^{23}+1,g=3\)

\(1004535809=479\times 2^{21}+1,g=3\)

这两个是常见的模数

代码:

void ntt(ll *f, ll n, ll typ)
{
	for(ll i = 0; i < n; ++i)	rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
	for(ll i = 0; i < n; ++i)
		if(i < rev[i])	swap(f[i], f[rev[i]]);
	for(ll step = 2; step <= n; step <<= 1)
	{
		ll gap = step >> 1, w = qmi(g, (mod - 1) / step);
		if(typ < 0)	w = qmi(w, mod - 2);
		for(ll i = 0; i < n; i += step)
		{
			ll cur = 1; 
			for(ll j = i; j < i + gap; ++j, cur = cur * w % mod) 
			{
				ll p = f[j], q = f[j + gap] * cur % mod;
				f[j] = add(p, q), f[j + gap] = add(p, mod - q);
			}
		}
	}
}
int main()
{
	n = read(), m = read();
	for(ll i = 0; i <= n; ++i)	a[i] = read();
	for(ll i = 0; i <= m; ++i)	b[i] = read();
	for(; lim <= n + m; lim <<= 1);
	ntt(a, lim, 1), ntt(b, lim, 1);
	for(ll i = 0; i < lim; ++i)	a[i] = a[i] * b[i] % mod;
	ntt(a, lim, -1);
	inv = qmi(lim, mod - 2);
	for(ll i = 0; i <= n + m; ++i)	print(a[i] * inv % mod), putchar(' ');
	return 0;
}

循环卷积

\(h_k=\sum_{i=0}^n\sum_{j=0}^n [(i+j)\bmod (n+1)=k]f_ig_j\)

求出 \(h'=f\times g\),则

\(h_k=h'_k+h'_{k+n+1}(k=0\sim n-1)\)

\(h_n=h'_n\)

差卷积

\(h_k=\sum_{i=k}^nf_ig_{i-k}\)

\(g\) 的系数反过来,为 \(g'\),求出 \(h'=f\times g'\),则 \(h_k=h'_{k+n}\)


应用

有些式子,需要 \(O(n)\) 计算每一项,要计算全部的 \(n\)

但是把它拆成两个函数的卷积,这样就可以 \(O(n\log n)\) 计算出全部

P4491 [HAOI2018] 染色

数论知识杂记2 已经推出了式子

\[f_i={M\choose i}{N\choose Si}\frac{(Si)!}{(S!)^i}(M-i)^{N-Si} \]

\[g_k=\sum_{i=k}^{mn}(-1)^{i-k}{i\choose k}f_i \]

\(g_k\) 展开,

\[g_k=\sum_{i=k}^{mn}(-1)^{i-k}\frac{i!f_i}{k!(i-k)!} \]

\(g1(x)=i!f_i\)\(g2(x)=\frac{(-1)^x}{x!}\)

\(g_k=\frac 1 {k!}\sum_{i=k}^{mn}g1_ig2_{i-k}\)

这是差卷积的形式

于是用 NTT 优化,复杂度 \(O(n\log n)\)

void ntt(ll *f, ll n, ll typ)
{
	for(ll i = 0; i <= n; ++i)	rev[i] = 0;
	for(ll i = 0; i < n; ++i)	rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
	for(ll i = 0; i < n; ++i)
		if(i < rev[i])	swap(f[i], f[rev[i]]);
	for(ll step = 2; step <= n; step <<= 1)
	{
		ll gap = step >> 1, w = qmi(G, (mod - 1) / step);
		if(typ < 0)	w = qmi(w, mod - 2);
		for(ll i = 0; i < n; i += step)
		{
			ll cur = 1;
			for(ll j = i; j < i + gap; ++j, cur = cur * w % mod)
			{
				ll x = f[j], y = f[j + gap] * cur % mod;
				f[j] = add(x, y), f[j + gap] = add(x, mod - y); 
			}
		}
	}
} 
int main()
{
	n = read(), m = read(), s = read();
	mx = max(n, m);
	for(ll i = 0; i <= m; ++i)	w[i] = read(), typ |= (w[i] != 0);
	if(!typ)
	{
		printf("0");
		return 0;
	}
	mn = min(n / s, m), f[0] = qmi(m, n), mx = max(mx, mn << 1);
	fact[0] = invf[0] = 1;
	for(ll i = 1; i <= mx; ++i)	fact[i] = fact[i - 1] * i % mod;
	invf[mx] = qmi(fact[mx], mod - 2);
	for(ll i = mx - 1; i > 0; --i)	invf[i] = invf[i + 1] * (i + 1) % mod;
	for(ll i = 1; i <= mn; ++i)
		f[i] = c(m, i) * c(n, s * i) % mod * fact[s * i] % mod * qmi(invf[s], i) % mod * qmi(m - i, n - s * i) % mod;
	for(ll i = 0; i <= mn; ++i)	
		p[i] = (i & 1) ? add(0, mod - invf[i]) : invf[i], q[i] = fact[i] * f[i] % mod;
	reverse(p, p + mn + 1);
	for(lim = 1; lim <= mn << 1; lim <<= 1);
	ntt(p, lim, 1), ntt(q, lim, 1);
	for(ll i = 0; i < lim; ++i)	p[i] = p[i] * q[i] % mod;
	ntt(p, lim, -1), invl = qmi(lim, mod - 2);
	for(ll i = mn; i <= mn << 1; ++i)	g[i - mn] = p[i] * invl % mod * invf[i - mn] % mod;
	for(ll i = 0; i <= m; ++i)	ans = add(ans, g[i] * w[i] % mod);
	printf("%lld", ans);
	return 0;
}
posted @ 2023-07-09 16:05  KellyWLJ  阅读(10)  评论(0编辑  收藏  举报  来源