FFT快速傅里叶变换模板
Description
FFT快速傅里叶变换模板,在\(O(nlogn)\)的时间内求一个多项式与另一个多项式的乘积。
Input
第一行给出第一个多项式的次数上界\(len1\),也是第一个多项式的项数。
第二行给出\(len1\)个数,分别表示第一个多项式的\(0\)阶项,\(1\)阶项,\(\cdots\),\(len1-1\)阶项的系数。
第三行给出第二个多项式的次数上界\(len2\),也是第二个多项式的项数。
第四行给出\(len2\)个数,分别表示第二个多项式的\(0\)阶项,\(1\)阶项,\(\cdots\),\(len2-1\)阶项的系数。
Output
第一行输出乘积多项式的次数上界\(len\),也是乘积多项式的项数。
第二行输出\(len\)个数,分别表示乘积多项式的\(0\)阶项,\(1\)阶项,\(\cdots\),\(len-1\)阶项的系数。
Sample Input
4
1 -1 2 -1
3
2 1 2
Sample Output
6
2 -1 5 -2 3 -2
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <complex>
#include <cmath>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int N = 4e3 + 10;
const double PI = acos(-1);
typedef complex<double> Complex;
Complex A[N], B[N];
void Rader(Complex y[], int len)
{
for (int i = 1, j = len / 2; i < len - 1; i++)
{
if (i < j) swap(y[i], y[j]);
int k = len / 2;
while (j >= k) j -= k, k /= 2;
if (j < k) j += k;
}
}
void fft(Complex y[], int len, int op)
{
Rader(y, len);
for (int h = 2; h <= len; h <<= 1)
{
Complex wn(cos(op * 2 * PI / h), sin(op * 2 * PI / h));
for (int i = 0; i < len; i += h)
{
Complex w(1, 0);
for (int j = i; j < i + h / 2; j++)
{
Complex u = y[j];
Complex t = w * y[j + h / 2];
y[j] = u + t;
y[j + h / 2] = u - t;
w = w * wn;
}
}
}
if (op == -1)
for (int i = 0; i < len; i++)
y[i] = Complex(y[i].real() / len, y[i].imag());
}
int FFT(int a[], int len1, int b[], int len2)
{
int len = 1;
while (len < len1 * 2 || len < len2 * 2) len <<= 1;
for (int i = 0; i < len1; i++) A[i] = a[i];
for (int i = len1; i < len; i++) A[i] = 0;
for (int i = 0; i < len2; i++) B[i] = b[i];
for (int i = len2; i < len; i++) B[i] = 0;
fft(A, len, 1);
fft(B, len, 1);
for (int i = 0; i < len; i++) A[i] = A[i] * B[i];
fft(A, len, -1);
for (int i = 0; i < len; i++) a[i] = round(A[i].real());
while (len > 1 && a[len - 1] == 0) len--;
return len;
}
int a[N], b[N], c[N];
int main()
{
int len1;
scanf("%d", &len1);
for (int i = 0; i < len1; i++) scanf("%d", a + i);
int len2;
scanf("%d", &len2);
for (int i = 0; i < len2; i++) scanf("%d", b + i);
int len = FFT(a, len1, b, len2);
printf("%d\n", len);
for (int i = 0; i < len; i++) printf("%d ", a[i]);
printf("\n");
return 0;
}