多项式乘法
多项式乘法
和卷积
若 \(h=f\times g\),则 \(h_n=\sum_{i=0}^n f_ig_{n-i}\)
朴素的乘法是 \(O(n^2)\) 的,看起来到了极限
但怎样再快一些?
发现如果把多项式换成点值表示,则 \(O(n)\) 即可计算!
问题是,用系数表示的次数为 \(n\) 的多项式化成点值就要 \(O(n^2)\)
于是引入了 系数 \(\to\) 点值 \(\to\) 系数的各种多项式乘法算法
FFT
前置知识:复数
DFT:通过系数算出点值
想快速计算点值,考虑取各种特殊点
但实数域内的特殊数有限,考虑扩展到复数域
对于次数为 \(n-1\) 的多项式,代入 \(n\) 个点值
取 \(\omega_n^{0\sim n-1}\),即单位复根
复平面上,它的终点把单位圆等分为 \(n\) 等份
它的性质:(可以用单位圆的图示理解)
-
\(\omega_n^0=\omega_n^n=1\) 在实轴上
-
\(\omega_n^k=-\omega_n^{k+\frac n 2}\) 关于原点对称
-
\(\omega_n^k=\omega_n^{n+k}\) 转了一圈
-
\(\omega_n^k=\omega_{2n}^{2k}\) 辐角一样
然后,把多项式奇偶次项分开
\(f(x)=a_0x^0+a_2x^2+a_4x^4\dots+a_1x_1+a_3x^3+a_5x^5\)
令 \(g(x)=a_0x^0+a_2x^1+a_4x^2\dots\),\(h(x)=a_1x^0+a_3x^1+a_5x^2\)
\(f(x)=g(x^2)+x\times h(x^2)\)
发现多项式的次数仅为原来的 \(\frac 1 2\)!
代入 \(n\) 个单位复根
\(f(\omega_n^i)=g((\omega_n^i)^2)+\omega_n^i\times h((\omega_n^i)^2)\)
\(f(\omega_n^{i+\frac n 2})=g((\omega_n^{i+\frac n 2})^2)+\omega_n^{i+\frac n 2}\times h((\omega_n^{i+\frac n 2})^2)=g((\omega_n^i)^2)-\omega_n^i\times h((\omega_n^i)^2)\)
发现这两个式子只有一个常数项不同,于是只计算前面一半可以直接得到后面
那么,分成两半递归分治计算 \(f\),复杂度是 \(O(n\log n)\)
IDFT:点值重新变为系数
证明看不会了……其实是懒得看,况且不知道好像没关系
感性理解一下,“相反”,把单位复根全部变为它的共轭复数即可
结论有:第 \(i\) 项最后的实数部分为 \(c_i\),则系数 \(a_i=\frac{c_i}n\)
注意实现时拿 STL 中的 complex<double>
,用 real()
取出实数部分
因为每次分治,所以次数应补全为 \(lim=2^k\)
递归版代码:
void fft(comp *f, ll n, ll typ)
{
if(n == 1) return;
for(ll i = 0; i < n; ++i) tmp[i] = f[i];
for(ll i = 0; i < n; ++i) // 奇偶分开
if(i & 1) f[(i >> 1) + (n >> 1)] = tmp[i];
else f[i >> 1] = tmp[i];
comp *g = f, *h = f + n / 2;
fft(g, n >> 1, typ), fft(h, n >> 1, typ); // 递归计算
comp cur(1, 0), w(cos((double)2.0 * pi / n), sin((double)2.0 * pi / n) * typ);
for(ll i = 0; i < (n >> 1); ++i, cur *= w) // 算出当前的值
tmp[i] = g[i] + h[i] * cur, tmp[i + (n >> 1)] = g[i] - h[i] * cur;
for(ll i = 0; i < n; ++i) f[i] = tmp[i]; // 更新
}
int main()
{
n = read(), m = read();
for(ll i = 0; i <= n; ++i) a[i] = {(double)read(), 0.0};
for(ll i = 0; i <= m; ++i) b[i] = {(double)read(), 0.0};
for(; sizn <= n + m; sizn <<= 1);
fft(a, sizn, 1), fft(b, sizn, 1);
for(ll i = 0; i <= sizn; ++i) a[i] = a[i] * b[i];
fft(a, sizn, -1);
for(ll i = 0; i <= n + m; ++i) print((ll)(a[i].real() / sizn + 0.5)), putchar(' ');
return 0;
}
优化
递归版常数比较大,加上复数运算自带大常数
于是想变成非递归版
一组一组的分治下去,要知道最后一组两两配对的情况
这里把每个数与它的二进制表示反转后表示的数交换(高位补的 0 到前面)
可以递推求出
然后枚举每层的次数,按上述步骤计算即可
void fft(comp *f, ll n, ll typ) // 非递归版,自底向上模拟递归过程
{
for(ll i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
for(ll i = 0; i < n; ++i) // 交换顺序
if(i < rev[i]) swap(f[i], f[rev[i]]);
for(ll step = 2; step <= n; step <<= 1) // 当前计算的组内的元素个数
{
ll gap = step >> 1;
comp w(cos(pi * (double)2.0 / step), typ * sin(pi * (double)2.0 / step));
for(ll i = 0; i < n; i += step) // 跳到下一组
{
comp cur(1.0, 0.0);
for(ll j = i; j < i + gap; ++j, cur *= w) // 与下层合并
{
comp p = f[j], q = cur * f[j + gap];
f[j] = p + q, f[j + gap] = p - q;
}
}
}
}
int main()
{
n = read(), m = read();
for(ll i = 0; i <= n; ++i) a[i] = {(double)read(), 0.0};
for(ll i = 0; i <= m; ++i) b[i] = {(double)read(), 0.0};
for(; sizn <= n + m; sizn <<= 1);
fft(a, sizn, 1), fft(b, sizn, 1);
for(ll i = 0; i <= sizn; ++i) a[i] = a[i] * b[i];
fft(a, sizn, -1);
for(ll i = 0; i <= n + m; ++i) print((ll)(a[i].real() / sizn + 0.5)), putchar(' ');
return 0;
}
NTT
如果多项式系数较大,题目中需要取模,就不好用复数计算
那有什么在模意义下能代替复数呢?
模数为质数,那它的原根 \(g^{p-1}\equiv1(\bmod\ p)\),且 \(g^{0\sim p-2}\) 互不相同
它甚至有与单位复根相似的性质:
-
\(g^n\equiv g^{n+\frac {p-1}2}(\bmod\ p)\)
-
\(g^n\equiv g^{n+p-1}(\bmod\ p)\)
那么,直接把 FFT 中的单位复根换成原根即可
但对模数要求较为苛刻,因为把 \(p-1\) \(2^k\) 等分并不容易
\(998244353=7\times17\times2^{23}+1,g=3\)
\(1004535809=479\times 2^{21}+1,g=3\)
这两个是常见的模数
代码:
void ntt(ll *f, ll n, ll typ)
{
for(ll i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
for(ll i = 0; i < n; ++i)
if(i < rev[i]) swap(f[i], f[rev[i]]);
for(ll step = 2; step <= n; step <<= 1)
{
ll gap = step >> 1, w = qmi(g, (mod - 1) / step);
if(typ < 0) w = qmi(w, mod - 2);
for(ll i = 0; i < n; i += step)
{
ll cur = 1;
for(ll j = i; j < i + gap; ++j, cur = cur * w % mod)
{
ll p = f[j], q = f[j + gap] * cur % mod;
f[j] = add(p, q), f[j + gap] = add(p, mod - q);
}
}
}
}
int main()
{
n = read(), m = read();
for(ll i = 0; i <= n; ++i) a[i] = read();
for(ll i = 0; i <= m; ++i) b[i] = read();
for(; lim <= n + m; lim <<= 1);
ntt(a, lim, 1), ntt(b, lim, 1);
for(ll i = 0; i < lim; ++i) a[i] = a[i] * b[i] % mod;
ntt(a, lim, -1);
inv = qmi(lim, mod - 2);
for(ll i = 0; i <= n + m; ++i) print(a[i] * inv % mod), putchar(' ');
return 0;
}
循环卷积
\(h_k=\sum_{i=0}^n\sum_{j=0}^n [(i+j)\bmod (n+1)=k]f_ig_j\)
求出 \(h'=f\times g\),则
\(h_k=h'_k+h'_{k+n+1}(k=0\sim n-1)\)
\(h_n=h'_n\)
差卷积
\(h_k=\sum_{i=k}^nf_ig_{i-k}\)
把 \(g\) 的系数反过来,为 \(g'\),求出 \(h'=f\times g'\),则 \(h_k=h'_{k+n}\)
应用
有些式子,需要 \(O(n)\) 计算每一项,要计算全部的 \(n\) 项
但是把它拆成两个函数的卷积,这样就可以 \(O(n\log n)\) 计算出全部
在 数论知识杂记2 已经推出了式子
把 \(g_k\) 展开,
令 \(g1(x)=i!f_i\),\(g2(x)=\frac{(-1)^x}{x!}\)
则 \(g_k=\frac 1 {k!}\sum_{i=k}^{mn}g1_ig2_{i-k}\)
这是差卷积的形式
于是用 NTT 优化,复杂度 \(O(n\log n)\)
void ntt(ll *f, ll n, ll typ)
{
for(ll i = 0; i <= n; ++i) rev[i] = 0;
for(ll i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
for(ll i = 0; i < n; ++i)
if(i < rev[i]) swap(f[i], f[rev[i]]);
for(ll step = 2; step <= n; step <<= 1)
{
ll gap = step >> 1, w = qmi(G, (mod - 1) / step);
if(typ < 0) w = qmi(w, mod - 2);
for(ll i = 0; i < n; i += step)
{
ll cur = 1;
for(ll j = i; j < i + gap; ++j, cur = cur * w % mod)
{
ll x = f[j], y = f[j + gap] * cur % mod;
f[j] = add(x, y), f[j + gap] = add(x, mod - y);
}
}
}
}
int main()
{
n = read(), m = read(), s = read();
mx = max(n, m);
for(ll i = 0; i <= m; ++i) w[i] = read(), typ |= (w[i] != 0);
if(!typ)
{
printf("0");
return 0;
}
mn = min(n / s, m), f[0] = qmi(m, n), mx = max(mx, mn << 1);
fact[0] = invf[0] = 1;
for(ll i = 1; i <= mx; ++i) fact[i] = fact[i - 1] * i % mod;
invf[mx] = qmi(fact[mx], mod - 2);
for(ll i = mx - 1; i > 0; --i) invf[i] = invf[i + 1] * (i + 1) % mod;
for(ll i = 1; i <= mn; ++i)
f[i] = c(m, i) * c(n, s * i) % mod * fact[s * i] % mod * qmi(invf[s], i) % mod * qmi(m - i, n - s * i) % mod;
for(ll i = 0; i <= mn; ++i)
p[i] = (i & 1) ? add(0, mod - invf[i]) : invf[i], q[i] = fact[i] * f[i] % mod;
reverse(p, p + mn + 1);
for(lim = 1; lim <= mn << 1; lim <<= 1);
ntt(p, lim, 1), ntt(q, lim, 1);
for(ll i = 0; i < lim; ++i) p[i] = p[i] * q[i] % mod;
ntt(p, lim, -1), invl = qmi(lim, mod - 2);
for(ll i = mn; i <= mn << 1; ++i) g[i - mn] = p[i] * invl % mod * invf[i - mn] % mod;
for(ll i = 0; i <= m; ++i) ans = add(ans, g[i] * w[i] % mod);
printf("%lld", ans);
return 0;
}