MTT学习小记
这是个毒瘤题才有的毒瘤东西……奶一口NOI不考
拆系数FFT:
考虑做NTT时模数不是NTT模数(\(2^a*b+1\))怎么办?
很容易想到拆次数FFT。
比如说现在求\(a*b\),设\(m=\sqrt mo(2^{15})\)
那么把\(a[i]\)拆成\(a0[i]+a1[i]*m\),b[i]拆成\(b0[i]+b1[i]*m\)
那么\(a[i]*b[j]=a0[i]*b0[j]+(a0[i]*b1[j]+a1[i]*b0[j])*m+(a1[i]*b1[j])*m^2\)
由于\(a0,b1,b0,b1\)的大小都不到,所以FFT不会爆精度。
那么这个最好也需要4(正)+3(逆)=7,复杂度不能接受。
DFT:
一开始要做4次DFT,我们两两一起做。
假设现在有两个序列A、B,要求DFT(A)和DFT(B)
设\(P=A+B*i,Q=A-B*i\)
只用做P的DFT,便可得到Q的DFT。
\(DFT(Q)[n-i]=conj(DFT(P)[i])\),\(conj\)为共轭,就是把虚部系数取反。
证明:
\(conj(DFT(P)[i])\)
\(=conj(\sum_{j=0}^{n-1}w_n^{ji}*(A[i]+B[i]*\sqrt{-1})\)
\(=\sum_{j=0}^{n-1}conj(w_n^{ji})*conj(A[i]+B[i]*\sqrt{-1})\)
\(=\sum_{j=0}^{n-1}w_n^{-ij}*(A[i]-B[i]*\sqrt{-1})\)
\(=DFT(Q)[n-i]\)
这样求出了\(DFT(P),DFT(Q)\)
那么\(DFT(A)=(DFT(P)+DFT(Q))*({1\over2},0),\\DFT(B)=(DFT(P)-DFT(Q))*(-1)*(0, {1\over2})\)
这个东西显然有一个条件是\(A、B\)只能实部有值,不然会混乱了无法提出来的。
IDFT:
接下来是IDFT,同样的可以两个 一起做。
如果有\(A、B\),都只有实部有值,设\(C=DFT(A)+i*DFT(B)\)
显然\(IDFT(C)\)的实部就是A,虚部就是B
这样我们就用四次DFT完成啦!
Code:
#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
#define ff(i, x, y) for(int i = x, B = y; i < B; i ++)
#define fd(i, x, y) for(int i = x, B = y; i >= B; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
#define db double
using namespace std;
const db pi = acos(-1);
const int mo = 1e9 + 7;
struct P {
db x, y;
P(db _x = 0, db _y = 0) { x = _x, y = _y;}
P operator + (P b) { return P(x + b.x, y + b.y);}
P operator - (P b){ return P(x - b.x, y - b.y);}
P operator * (P b) { return P(x * b.x - y * b.y, x * b.y + y * b.x);}
};
const int nm = 1 << 18;
P w[nm]; int r[nm];
P c0[nm], c1[nm], c2[nm], c3[nm];
void dft(P *a, int n) {
ff(i, 0, n) {
r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
if(i < r[i]) swap(a[i], a[r[i]]);
} P b;
for(int i = 1; i < n; i *= 2) for(int j = 0; j < n; j += 2 * i)
ff(k, 0, i) b = a[i + j + k] * w[i + k], a[i + j + k] = a[j + k] - b, a[j + k] = a[j + k] + b;
}
void rev(P *a, int n) {
reverse(a + 1, a + n);
ff(i, 0, n) a[i].x /= n, a[i].y /= n;
}
P conj(P a) { return P(a.x, -a.y);}
void fft(ll *a, ll *b, int n) {
#define qz(x) ((ll) round(x))
// ff(i, 0, n) c0[i] = P(a[i], 0), c1[i] = P(b[i], 0);
// dft(c0, n); dft(c1, n);
// ff(i, 0, n) c0[i] = c0[i] * c1[i];
// dft(c0, n); rev(c0, n);
// ff(i, 0, n) a[i] = qz(c0[i].x);
ff(i, 0, n) c0[i] = P(a[i] & 32767, a[i] >> 15), c1[i] = P(b[i] & 32767, b[i] >> 15);
dft(c0, n); dft(c1, n);
ff(i, 0, n) {
P k, d0, d1, d2, d3;
int j = (n - i) & (n - 1);
k = conj(c0[j]);
d0 = (k + c0[i]) * P(0.5, 0);
d1 = (k - c0[i]) * P(0, 0.5);
k = conj(c1[j]);
d2 = (k + c1[i]) * P(0.5, 0);
d3 = (k - c1[i]) * P(0, 0.5);
c2[i] = d0 * d2 + d1 * d3 * P(0, 1);
c3[i] = d0 * d3 + d1 * d2;
}
dft(c2, n); dft(c3, n); rev(c2, n); rev(c3, n);
ff(i, 0, n) {
a[i] = qz(c2[i].x) + (qz(c2[i].y) % mo << 30) + (qz(c3[i].x) % mo << 15);
a[i] %= mo;
}
}
ll a[nm], b[nm];
int main() {
for(int i = 1; i < nm; i *= 2) ff(j, 0, i)
w[i + j] = P(cos(pi * j / i), sin(pi * j / i));
fo(i, 0, 15) a[i] = b[i] = mo - 1;
fft(a, b, 32);
ff(i, 0, 32) pp("%lld ", a[i]);
}