FFT & NTT
FFT 快速傅里叶变换
在 \(O(n\log n)\) 时间内计算多项式乘法
由 系数表示法 转换为 点值表示法
记
\[\omega_n^k = \cos\left(\dfrac{2\pi k}{n}\right) + i \sin\left(\dfrac{2\pi k}{n}\right)
\]
\[A(x)=a_0+a_1*x+a_2*{x^2}+a_3*{x^3}+a_4*{x^4}+a_5*{x^5}+\\ \dots+a_{n-2}*x^{n-2}+a_{n-1}*x^{n-1}
\]
\[A(x)=(a_0+a_2*{x^2}+a_4*{x^4}+\dots+a_{n-2}*x^{n-2})+\\(a_1*x+a_3*{x^3}+a_5*{x^5}+ \dots+a_{n-1}*x^{n-1})
\]
设
\[A_1(x)=a_0+a_2*{x}+a_4*{x^2}+\dots+a_{n-2}*x^{\frac{n}{2}-1}
\]
\[A_2(x)=a_1+a_3*{x}+a_5*{x^2}+ \dots+a_{n-1}*x^{\frac{n}{2}-1}
\]
则
\[A(x)=A_1(x^2)+xA_2(x^2)
\]
代入 \(x = \omega_n^k\)(\(0 \le k < \frac n2\)),并利用 \((\omega_n^k)^2 = \omega_n^{2k} = \omega_{\frac n2}^k\):
\[A(\omega_n^k) = A_1(\omega_{\frac n2}^k) + \omega_n^kA_2(\omega_{\frac n2}^k)
\]
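代入 \(x = \omega_n^{k+\frac n2}\),并利用 \(\omega_n^{k+\frac n2} = -\omega_n^k\) 以及 \((\omega_n^{k+\frac n2})^2 = \omega_{\frac n2}^k\):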
\[A(\omega_n^{k+\frac n2}) = A_1(\omega_{\frac n2}^k) -\omega_n^kA_2(\omega_{\frac n2}^k)
\]
也就是说,如果知道了 \(A_1(x)\)、\(A_2(x)\) 分别在 \(\omega_{\frac n2}^0\), \(\omega_{\frac n2}^1\), \(\omega_{\frac n2}^2\), \(\dots\), \(\omega_{\frac n2}^{\frac n2 -1}\) 处的取值,
就可以 \(O(n)\) 地求出 \(A(x)\) 在 \(\omega_n^0, \omega_n^1, \dots, \omega_n^{n-1}\) 处的取值
// Recursive radix-2 FFT, computed in place on a[0..n-1].
//
// a   : coefficients on entry, values at the n-th roots of unity on exit.
// n   : transform length; assumed to be a power of two (an odd n > 1 would
//       silently drop the last element when splitting) and at most MAXN.
// inv : sign applied to the imaginary part of the root of unity; pass 1 for
//       the forward transform and -1 to use the conjugate roots (inverse
//       transform, before dividing by n).
void fft(cp *a,int n,int inv)
{
    // n<=1 (not n==1): a length of 0 would otherwise recurse forever.
    if (n<=1)return;
    int mid=n/2;
    // Shared scratch buffer. It is only used before and after the recursive
    // calls, never across them, so sharing it between recursion levels is safe.
    static cp b[MAXN];
    // Split into even-indexed (left half) and odd-indexed (right half) terms.
    for(int i = 0;i < mid;i++)b[i]=a[i*2],b[i+mid]=a[i*2+1];
    for(int i = 0;i < n;i++)a[i]=b[i];
    fft(a,mid,inv),fft(a+mid,mid,inv);// divide and conquer
    for(int i = 0;i < mid;i++)
    {
        // x = omega_n^i, conjugated when inv == -1.
        cp x(cos(2*pi*i/n),inv*sin(2*pi*i/n));
        cp t=x*a[i+mid];
        b[i]=a[i]+t,b[i+mid]=a[i]-t;
    }
    // Copy the merged result back; the bound is n (the original compared
    // the index against the pointer a).
    for(int i = 0;i < n;i++)a[i]=b[i];
}
每个位置分治后最终的位置是二进制翻转后的位置,因此可以先按二进制翻转重排,再自底向上迭代合并,省去递归:
// Iterative in-place radix-2 FFT on a[0..n-1].
//
// a   : data to transform, overwritten with the result.
// n   : transform length (the butterfly loops assume a power of two).
// inv : sign applied to the imaginary part of the root of unity
//       (1 = forward, -1 = conjugate roots).
// Fills the external table rev[] as a side effect.
void fft(cp *a,int n,int inv)
{
    // bit = smallest exponent such that 2^bit >= n.
    int bit=0;
    for (;(1<<bit)<n;)++bit;

    // Build the bit-reversal table and permute the input along the way.
    fo(i,0,n-1)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
        // Swap each pair only once; without the guard every pair would be
        // swapped twice, i.e. not at all.
        if (rev[i]>i)swap(a[i],a[rev[i]]);
    }

    // half is half the length of the sequences being merged in this pass.
    for (int half=1;half<n;half*=2)
    {
        // Primitive root of unity for length 2*half; the factor 2 in
        // 2*pi/(2*half) has been cancelled.
        cp step(cos(pi/half),inv*sin(pi/half));
        // start is the first index of the sequence currently being merged.
        for (int start=0;start<n;start+=half*2)
        {
            cp omega(1,0);
            // Walk the left half only; each step also yields the matching
            // entry of the right half (butterfly).
            for (int j=0;j<half;j++)
            {
                cp lo=a[start+j],hi=omega*a[start+j+half];
                a[start+j]=lo+hi;
                a[start+j+half]=lo-hi;
                omega*=step;
            }
        }
    }
}
注意 lim:变换长度 lim 必须是 2 的幂,并且严格大于 n+m(乘积多项式共有 n+m+1 项系数);同时 lim 不能超过数组大小 N。
#include<bits/stdc++.h>
using namespace std;
// pi, obtained as acos(-1.0); FFT uses it to build the roots of unity.
const double pi = acos(-1.0);
// Capacity of the global arrays rev, A and B; the transform length lim
// must not exceed it.
const int N = 3e6 + 10;
// Minimal complex number: x is the real part, y the imaginary part.
// A default-constructed value is 0 + 0i. The arithmetic operators are
// const-qualified so they can be applied to const operands and temporaries.
struct cp {
    double x = 0.0;
    double y = 0.0;

    cp() = default;
    cp(double _x, double _y) : x(_x), y(_y) {}

    // Component-wise sum.
    cp operator+(const cp& b) const {
        return cp(x + b.x, y + b.y);
    }
    // Component-wise difference.
    cp operator-(const cp& b) const {
        return cp(x - b.x, y - b.y);
    }
    // Complex product: (x + yi)(b.x + b.y i).
    cp operator*(const cp& b) const {
        return cp(x * b.x - y * b.y, x * b.y + y * b.x);
    }
};
// rev[i] is read by FFT as the index that position i is exchanged with
// before the butterfly passes; main fills it in.
int rev[N];
// Incremented by main once per doubling of lim.
int bit = 0;
// Transform length used by FFT; set by main.
int lim;

// Iterative in-place radix-2 FFT over a[0..lim-1].
//
// a   : data to transform, overwritten with the result.
// inv : sign applied to the imaginary part of the root of unity
//       (main passes 1 for the forward and -1 for the inverse transform;
//       the caller divides by lim after the inverse).
// Requires rev[] and lim to be set up beforehand.
void FFT(cp* a, int inv) {
    // Reorder the input; the guard makes sure each pair is swapped once.
    for (int i = 0; i < lim; ++i) {
        const int j = rev[i];
        if (j > i) {
            swap(a[i], a[j]);
        }
    }

    // half is half the length of the sequences merged in this pass.
    for (int half = 1; half < lim; half *= 2) {
        const int span = half * 2;
        // Primitive root of unity for length span (2*pi/span == pi/half).
        cp step(cos(pi / half), inv * sin(pi / half));
        for (int start = 0; start < lim; start += span) {
            cp* lo = a + start;
            cp* hi = lo + half;
            cp w(1, 0);
            for (int k = 0; k < half; ++k) {
                // Butterfly: combine the k-th entries of both halves.
                cp u = lo[k];
                cp v = w * hi[k];
                lo[k] = u + v;
                hi[k] = u - v;
                w = w * step;
            }
        }
    }
}
int n, m;
cp A[N], B[N];
int main() {
scanf("%d%d", &n, &m);
lim = 1;
while (lim <= n + m)lim<<=1,bit++;//调整至 2^k
for (int i = 0; i < lim; i++) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
for (int i = 0; i <= n; i++)scanf("%lf", &A[i].x), A[i].y = 0;
for (int i = 0; i <= m; i++)scanf("%lf", &B[i].x), B[i].y = 0;
FFT(A, 1);
FFT(B, 1);
for (int i = 0; i <= lim; i++) {
A[i] = A[i] * B[i];
}
FFT(A, -1);
for (int i = 0; i <= n + m; i++) {
printf("%d ", int(A[i].x /lim+0.5));
}
}
NTT
原根
还没有整太明白
待补,丢一个板子
#include<bits/stdc++.h>
// XOR-based swap macro; it replaces every later call written as swap(a,b).
// NOTE(review): it only works for integer lvalues and zeroes the value when
// both arguments refer to the same object.
#define swap(a,b) (a^=b,b^=a,a^=b)
using namespace std;
#define LL long long
// MAXN: capacity of the arrays r, a and b.
// P: the modulus. G: the base NTT raises to obtain roots of unity for the
// forward transform. Gi: its modular inverse (3 * 332748118 = P + 1), used
// for the inverse transform.
const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118;
// NOTE(review): buf, p1 and p2 are not used anywhere in the visible code;
// presumably left over from a fast-input routine.
char buf[1 << 21], * p1 = buf, * p2 = buf;
// N, M: polynomial degrees. limit: transform length. L: number of times
// limit was doubled. r: index table NTT uses to reorder its input.
int N, M, limit = 1, L, r[MAXN];
// Coefficient arrays of the two polynomials; a also receives the product.
LL a[MAXN], b[MAXN];
// Computes a^k mod P by binary exponentiation.
//
// a : base; reduced into [0, P) first so the squaring cannot overflow even
//     when the caller passes a value larger than P (or a negative one).
// k : exponent; a value <= 0 yields 1 % P (the original looped forever on a
//     negative k because the arithmetic shift never reaches 0).
inline LL fastpow(LL a, LL k) {
    a %= P;
    if (a < 0) a += P;
    LL base = 1;
    while (k > 0) {
        if (k & 1) base = (base * a) % P;
        a = (a * a) % P;
        k >>= 1;
    }
    return base % P;
}
// In-place number-theoretic transform of A[0..limit-1] modulo P.
//
// A    : values to transform, overwritten with the result.
// type : 1 selects G as the base for the roots of unity; any other value
//        selects Gi (main passes -1 for the inverse transform and multiplies
//        by the inverse of limit afterwards).
// Requires r[] and limit to be set up beforehand.
inline void NTT(LL* A, int type) {
    // Reorder the input; the guard makes sure each pair is swapped once.
    for (int i = 0; i < limit; ++i) {
        if (r[i] > i) {
            swap(A[i], A[r[i]]);
        }
    }

    // half is half the length of the sequences merged in this pass.
    for (int half = 1; half < limit; half *= 2) {
        const int span = half << 1;
        // Root of unity of order span modulo P.
        LL Wn = fastpow(type == 1 ? G : Gi, (P - 1) / span);
        for (int start = 0; start < limit; start += span) {
            LL w = 1;
            for (int k = 0; k < half; ++k) {
                const int lo = start + k;
                const int hi = lo + half;
                // Butterfly; + P keeps the difference non-negative.
                int x = A[lo];
                int y = w * A[hi] % P;
                A[lo] = (x + y) % P;
                A[hi] = (x - y + P) % P;
                w = (w * Wn) % P;
            }
        }
    }
}
int main() {
scanf("%d%d", &N, &M);
for (int i = 0; i <= N; i++) scanf("%d", a + i);
for (int i = 0; i <= M; i++) scanf("%d", b + i);
while (limit <= N + M) limit <<= 1, L++;
for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
NTT(a, 1); NTT(b, 1);
for (int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
NTT(a, -1);
LL inv = fastpow(limit, P - 2);
for (int i = 0; i <= N + M; i++)
printf("%d ", (a[i] * inv) % P);
return 0;
}