FFT快速傅里叶变换模板

Description

FFT快速傅里叶变换模板,在\(O(nlogn)\)的时间内求一个多项式与另一个多项式的乘积。

Input

第一行给出第一个多项式的次数上界\(len1\),也是第一个多项式的项数。
第二行给出\(len1\)个数,分别表示第一个多项式的\(0\)阶项,\(1\)阶项,\(\cdots\)\(len1-1\)阶项的系数。
第三行给出第二个多项式的次数上界\(len2\),也是第二个多项式的项数。
第四行给出\(len2\)个数,分别表示第二个多项式的\(0\)阶项,\(1\)阶项,\(\cdots\)\(len2-1\)阶项的系数。

Output

第一行输出乘积多项式的次数上界\(len\),也是乘积多项式的项数。
第二行输出\(len\)个数,分别表示乘积多项式的\(0\)阶项,\(1\)阶项,\(\cdots\)\(len-1\)阶项的系数。

Sample Input

4
1 -1 2 -1
3
2 1 2

Sample Output

6
2 -1 5 -2 3 -2

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <complex>
#include <cmath>
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int N = 4e3 + 10;

const double PI = acos(-1);
typedef complex<double> Complex;

Complex A[N], B[N];

void Rader(Complex y[], int len)
{
    for (int i = 1, j = len / 2; i < len - 1; i++)
    {
        if (i < j) swap(y[i], y[j]);
        int k = len / 2;
        while (j >= k) j -= k, k /= 2;
        if (j < k)  j += k;
    }
}

void fft(Complex y[], int len, int op)
{
    Rader(y, len);
    for (int h = 2; h <= len; h <<= 1)
    {
        Complex wn(cos(op * 2 * PI / h), sin(op * 2 * PI / h));
        for (int i = 0; i < len; i += h)
        {
            Complex w(1, 0);
            for (int j = i; j < i + h / 2; j++)
            {
                Complex u = y[j];
                Complex t = w * y[j + h / 2];
                y[j] = u + t;
                y[j + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (op == -1) 
        for (int i = 0; i < len; i++) 
            y[i] = Complex(y[i].real() / len, y[i].imag());
}

int FFT(int a[], int len1, int b[], int len2)
{
    int len = 1;
    while (len < len1 * 2 || len < len2 * 2) len <<= 1;
    for (int i = 0; i < len1; i++) A[i] = a[i];
    for (int i = len1; i < len; i++) A[i] = 0;
    for (int i = 0; i < len2; i++) B[i] = b[i];
    for (int i = len2; i < len; i++) B[i] = 0;
    fft(A, len, 1);
    fft(B, len, 1);
    for (int i = 0; i < len; i++) A[i] = A[i] * B[i];
    fft(A, len, -1);
    for (int i = 0; i < len; i++) a[i] = round(A[i].real());
    while (len > 1 && a[len - 1] == 0) len--;
    return len;
}

int a[N], b[N], c[N];

int main()
{
    int len1;
    scanf("%d", &len1);
    for (int i = 0; i < len1; i++) scanf("%d", a + i);

    int len2;
    scanf("%d", &len2);
    for (int i = 0; i < len2; i++) scanf("%d", b + i);

    int len = FFT(a, len1, b, len2);
    
    printf("%d\n", len);
    for (int i = 0; i < len; i++) printf("%d ", a[i]);
    printf("\n");
    return 0;
}
posted @ 2017-08-08 23:10  达达Mr_X  阅读(184)  评论(0编辑  收藏  举报