[学习笔记]MTT
简单记录一下 \(\rm MTT\) 的推导和使用。
7 次 DFT 的 MTT
最朴素,并且最容易理解。
令 \(\text{bas} = \sqrt p\),那么,若计算 \(\Lambda=F\circ G\),我们可以考虑将每一项拆成两个部分:
可得
根据 \(\text{bas}\) 的次数区分,我们需要做 \(4\) 次 \(\rm DFT\),针对 \(c_1,c_2,r_1,r_2\),然后做三次 \(\rm IDFT\),针对 \(c_1\circ c_2,\;c_1\circ r_2+c_2\circ r_1,\;r_1\circ r_2\),常数比较大,但是易于推导与理解。
5 次 DFT 的 MTT
事实上和 4 次只有一步之遥了。
第一种方法为什么常数那么大?因为每一次变换的时候,我们将 \((h_i,0)\)(其中 \(h\) 为任意一个函数) 作为一个复数进行变换,每次的变换,初始的虚部均为 \(0\),这个无用却不得不计算的部分带给我们很大的开销,考虑能否将这个东西用起来。
我们还是考虑计算 \(\Lambda=(c_1\text{bas}+r_1)(x)+(c_2\text{bas}+r_2)(x)\),显然,我们其实想要求 \(c_1\circ c_2,\;c_1\circ r_2,\;r_1\circ c_2\;,r_1\circ r_2\) 这四个卷积,并不需要将他们单独求出来,可以构造一个结构使其出现这四个玩意,考虑到 \((a+bi)(c+di)=ac-bd+(ad+bc)i\),并且,如果在 \(b\) 之前添一个负号,同理可得 \(ac+bd+(ad-bc)i\),这两个东西上下相加,实部和虚部均只剩下一个,上下相减也只会剩下一个,于是我们就可以求出这四个玩意。具体地,令
则
于是
这四个东西分别单独出现在实部和虚部,显然可以直接求的。
接下来统计一下我们需要的变换次数:三次正变换,针对 \(P,\bar P,Q\),两次逆变换,针对 \(P(x)\circ Q(x)\) 和 \(\bar P(x)\circ Q(x)\). 常数大大减小了!
思路确实简单,不过难点在想到如此构造。
typedef complex<long double> cplx;
const int maxn = 1e5;
const long double Pi = acos(-1.0);
int mod;
namespace __poly {
int rev[maxn * 6 + 5], n;
inline void prepare(int len) {
for (n = 1; n < len; n <<= 1);
for (int i = 0; i < n; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1)? n >> 1: 0);
return ;
}
inline void fft(vector<cplx>& f, const int opt) {
f.resize(n);
for (int i = 0; i < n; ++i) if (i < rev[i])
swap(f[i], f[rev[i]]);
for (int p = 2; p <= n; p <<= 1) {
int len = p >> 1;
cplx w(cos(Pi / len), opt * sin(Pi / len));
for (int k = 0; k < n; k += p) {
cplx buf(1, 0), tmp;
for (int i = k; i < k + len; ++i, buf *= w) {
tmp = f[i + len] * buf;
f[i + len] = f[i] - tmp;
f[i] = f[i] + tmp;
}
}
}
if (opt == -1) for (int i = 0; i < n; ++i) f[i] /= n;
return ;
}
vector<cplx> P, _P, Q, mul1, mul2;
const int bas = 32000;
inline void mtt(vector<int>& f, const vector<int>& g) {
P.resize(f.size()), _P.resize(f.size()), Q.resize(g.size());
for (int i = 0; i < f.size(); ++i) {
P[i] = cplx(f[i] / bas, f[i] % bas);
_P[i] = conj(P[i]);
}
for (int i = 0; i < g.size(); ++i)
Q[i] = cplx(g[i] / bas, g[i] % bas);
prepare(f.size() + g.size());
fft(P, 1), fft(_P, 1), fft(Q, 1);
transform(P.begin(), P.end(), Q.begin(), P.begin(), multiplies<cplx>());
transform(_P.begin(), _P.end(), Q.begin(), _P.begin(), multiplies<cplx>());
fft(P, -1), fft(_P, -1), f.resize(n);
for (int i = 0; i < n; ++i) {
int c1c2 = (ll)((P[i].real() + _P[i].real()) * 0.5 + 0.5) % mod;
int c1r2 = (ll)((P[i].imag() + _P[i].imag()) * 0.5 + 0.5) % mod;
int r1c2 = (ll)((P[i].imag() - _P[i].imag()) * 0.5 + 0.5) % mod;
int r1r2 = (ll)((_P[i].real() - P[i].real()) * 0.5 + 0.5) % mod;
f[i] = 1ll * c1c2 * bas % mod * bas % mod;
f[i] = (f[i] + 1ll * (c1r2 + r1c2) * bas % mod) % mod;
f[i] = (f[i] + r1r2) % mod;
f[i] = (f[i] + mod) % mod;
}
return ;
}
} // using namespace __poly;
4 次 DFT 的 MTT
还可以更快!事实上,我们可以直接从 \(P(x)\) 得到 \(\bar P(x)\),先把 \(P(x)\) 变换出来,对于其共轭,假设 \(p_k=a+bi,\bar p_k=a-bi\),不难发现 \(\bar p_k\) 经过变换之后就是 \(p_{n-k}\) 经过变换得到的东西,因为共轭复数与原来的复数关于实轴对称,不过一个特例的是 \(\bar p_0=p_0\),特殊处理一下就好了。
typedef complex<long double> cplx;
const int maxn = 1e5;
const long double Pi = acos(-1.0);
int mod;
namespace __poly {
int rev[maxn * 6 + 5], n;
inline void prepare(int len) {
for (n = 1; n < len; n <<= 1);
for (int i = 0; i < n; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1)? n >> 1: 0);
return ;
}
inline void fft(vector<cplx>& f, const int opt) {
f.resize(n);
for (int i = 0; i < n; ++i) if (i < rev[i])
swap(f[i], f[rev[i]]);
for (int p = 2; p <= n; p <<= 1) {
int len = p >> 1;
cplx w(cos(Pi / len), opt * sin(Pi / len));
for (int k = 0; k < n; k += p) {
cplx buf(1, 0), tmp;
for (int i = k; i < k + len; ++i, buf *= w) {
tmp = f[i + len] * buf;
f[i + len] = f[i] - tmp;
f[i] = f[i] + tmp;
}
}
}
if (opt == -1) for (int i = 0; i < n; ++i) f[i] /= n;
return ;
}
vector<cplx> P, _P, Q, mul1, mul2;
const int bas = 32000;
inline void mtt(vector<int>& f, const vector<int>& g) {
P.resize(f.size()), Q.resize(g.size());
for (int i = 0; i < f.size(); ++i)
P[i] = cplx(f[i] / bas, f[i] % bas);
for (int i = 0; i < g.size(); ++i)
Q[i] = cplx(g[i] / bas, g[i] % bas);
prepare(f.size() + g.size());
fft(P, 1), fft(Q, 1);
_P.resize(n);
for (int i = 0; i < n; ++i) _P[i] = conj(P[i? n - i: 0]);
transform(P.begin(), P.end(), Q.begin(), P.begin(), multiplies<cplx>());
transform(_P.begin(), _P.end(), Q.begin(), _P.begin(), multiplies<cplx>());
fft(P, -1), fft(_P, -1), f.resize(n);
for (int i = 0; i < n; ++i) {
int c1c2 = (ll)((P[i].real() + _P[i].real()) * 0.5 + 0.5) % mod;
int c1r2 = (ll)((P[i].imag() + _P[i].imag()) * 0.5 + 0.5) % mod;
int r1c2 = (ll)((P[i].imag() - _P[i].imag()) * 0.5 + 0.5) % mod;
int r1r2 = (ll)((_P[i].real() - P[i].real()) * 0.5 + 0.5) % mod;
f[i] = 1ll * c1c2 * bas % mod * bas % mod;
f[i] = (f[i] + 1ll * (c1r2 + r1c2) * bas % mod) % mod;
f[i] = (f[i] + r1r2) % mod;
f[i] = (f[i] + mod) % mod;
}
return ;
}
} // using namespace __poly;