[笔记]ACM笔记 - 利用FFT求卷积(求多项式乘法)
卷积
给定向量:
向量和:
数量积(内积、点积):
卷积:
例如:
卷积的最典型的应用就是多项式乘法(多项式乘法就是求卷积)。以下就用多项式乘法来描述、举例卷积与DFT。
关于多项式
对于多项式
多项式的系数表达方式:
则多项式的系数向量即为
多项式的点值表达方式:
离散傅里叶变换(DFT)
离散傅里叶变换(Discrete Fourier Transform,DFT)。在信号处理很重要的一个东西,这里物理意义以及其他应用暂不予理睬。在多项式中,DFT就是系数表式转换成点值表示的过程。
快速傅里叶变换(FFT)
快速傅里叶变换(Fast Fourier Transformation,FFT):快速计算DFT的算法,能够在
求FFT要用到复数。一个简单的模板:
struct Complex // 复数
{
double r, i;
Complex(double _r = 0, double _i = 0) :r(_r), i(_i) {}
Complex operator +(const Complex &b) {
return Complex(r + b.r, i + b.i);
}
Complex operator -(const Complex &b) {
return Complex(r - b.r, i - b.i);
}
Complex operator *(const Complex &b) {
return Complex(r*b.r - i*b.i, r*b.i + i*b.r);
}
};
递归实现FFT模板:来源
Complex* RecursiveFFT(Complex a[], int n)//n表示向量a的维数
{
if(n == 1)
return a;
Complex wn = Complex(cos(2*PI/n), sin(2*PI/n));
Complex w = Complex(1, 0);
Complex* a0 = new Complex[n >> 1];
Complex* a1 = new Complex[n >> 1];
for(int i = 0; i < n; i++)
if(i & 1) a1[(i - 1) >> 1] = a[i];
else a0[i >> 1] = a[i];
Complex *y0, *y1;
y0 = RecursiveFFT(a0, n >> 1);
y1 = RecursiveFFT(a1, n >> 1);
Complex* y = new Complex[n];
for(int k = 0; k < (n >> 1); k++)
{
y[k] = y0[k] + w*y1[k];
y[k + (n >> 1)] = y0[k] - w*y1[k];
w = w*wn;
}
delete a0;
delete a1;
delete y0;
delete y1;
return y;
}
非递归实现。模板:(来源忘了)
void change(Complex y[], int len) // 二进制平摊反转置换 O(logn)
{
int i, j, k;
for (i = 1, j = len / 2;i < len - 1;i++)
{
if (i < j)swap(y[i], y[j]);
k = len / 2;
while (j >= k)
{
j -= k;
k /= 2;
}
if (j < k)j += k;
}
}
void fft(Complex y[], int len, int on) //FFT:on=1; IFFT:on=-1
{
change(y, len);
for (int h = 2;h <= len;h <<= 1)
{
Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
for (int j = 0;j < len;j += h)
{
Complex w(1, 0);
for (int k = j;k < j + h / 2;k++)
{
Complex u = y[k];
Complex t = w*y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w*wn;
}
}
}
if (on == -1)
for (int i = 0;i < len;i++)
y[i].r /= len;
}
利用FFT求卷积
普通的计算多项式乘法的计算,时间复杂度
步骤一(补0)
在两个多项式前面补0,得到两个2n次多项式,设系数向量分别为
步骤二(求值)
使用FFT计算
步骤三(乘法)
把
步骤四(插值)
使用IFFT计算
综上
fft(x1, len, 1);
fft(x2, len, 1);
for (int i = 0;i < len;i++) {
x[i] = x1[i] * x2[i];
}
fft(x, len, -1);
例题
1.2016 acm香港网络赛 A题 A+B Problem
网上的代码(当时没保留出处。。。)
#include <algorithm>
#include <cstring>
#include <string.h>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include <cstdio>
#include <cmath>
#define LL long long
#define N 200005
#define INF 0x3ffffff
using namespace std;
const double PI = acos(-1.0);
struct Complex // 复数
{
double r, i;
Complex(double _r = 0, double _i = 0) :r(_r), i(_i) {}
Complex operator +(const Complex &b)
{
return Complex(r + b.r, i + b.i);
}
Complex operator -(const Complex &b)
{
return Complex(r - b.r, i - b.i);
}
Complex operator *(const Complex &b)
{
return Complex(r*b.r - i*b.i, r*b.i + i*b.r);
}
};
void change(Complex y[], int len) // 二进制平摊反转置换 O(logn)
{
int i, j, k;
for (i = 1, j = len / 2;i < len - 1;i++)
{
if (i < j)swap(y[i], y[j]);
k = len / 2;
while (j >= k)
{
j -= k;
k /= 2;
}
if (j < k)j += k;
}
}
void fft(Complex y[], int len, int on) //DFT和FFT
{
change(y, len);
for (int h = 2;h <= len;h <<= 1)
{
Complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
for (int j = 0;j < len;j += h)
{
Complex w(1, 0);
for (int k = j;k < j + h / 2;k++)
{
Complex u = y[k];
Complex t = w*y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w*wn;
}
}
}
if (on == -1)
for (int i = 0;i < len;i++)
y[i].r /= len;
}
const int M = 50000; // a数组所有元素+M,使a[i]>=0
const int MAXN = 800040;
Complex x1[MAXN];
int a[MAXN / 4]; //原数组
long long num[MAXN]; //利用FFT得到的数组
long long tt[MAXN]; //统计数组每个元素出现个数
int main()
{
int n = 0; // n表示除了0之外数组元素个数
int tot;
scanf("%d", &tot);
memset(num, 0, sizeof(num));
memset(tt, 0, sizeof(tt));
int cnt0 = 0; //cnt0 统计0的个数
int aa;
for (int i = 0;i < tot;i++)
{
scanf("%d", &aa);
if (aa == 0) { cnt0++;continue; } //先把0全删掉,最后特殊考虑0
else a[n] = aa;
num[a[n] + M]++;
tt[a[n] + M]++;
n++;
}
sort(a, a + n);
int len1 = a[n - 1] + M + 1;
int len = 1;
while (len < 2 * len1) len <<= 1;
for (int i = 0;i < len1;i++) {
x1[i] = Complex(num[i], 0);
}
for (int i = len1;i < len;i++) {
x1[i] = Complex(0, 0);
}
fft(x1, len, 1);
for (int i = 0;i < len;i++) {
x1[i] = x1[i] * x1[i];
}
fft(x1, len, -1);
for (int i = 0;i < len;i++) {
num[i] = (long long)(x1[i].r + 0.5);
}
len = 2 * (a[n - 1] + M);
for (int i = 0;i < n;i++) //删掉ai+ai的情况
num[a[i] + a[i] + 2 * M]--;
/*
for(int i = 0;i < len;i++){
if(num[i]) cout<<i-2*M<<' '<<num[i]<<endl;
}
*/
long long ret = 0;
int l = a[n - 1] + M;
for (int i = 0;i <= l; i++) //ai,aj,ak都不为0的情况
{
if (tt[i]) ret += (long long)(num[i + M] * tt[i]);
}
ret += (long long)(num[2 * M] * cnt0); // ai+aj=0的情况
if (cnt0 != 0)
{
if (cnt0 >= 3) { //ai,aj,ak都为0的情况
long long tmp = 1;
tmp *= (long long)(cnt0);
tmp *= (long long)(cnt0 - 1);
tmp *= (long long)(cnt0 - 2);
ret += tmp;
}
for (int i = 0;i <= l; i++)
{
if (tt[i] >= 2) { // x+0=x的情况
long long tmp = (long long)cnt0;
tmp *= (long long)(tt[i]);
tmp *= (long long)(tt[i] - 1);
ret += tmp * 2;
}
}
}
printf("%lld\n", ret);
return 0;
}