[学习笔记]FFT——快速傅里叶变换

大力推荐博客:

傅里叶变换(FFT)学习笔记

 

一、多项式乘法:


我们要明白的是:

FFT利用分治,处理多项式乘法,达到O(nlogn)的复杂度。(虽然常数大)

FFT=DFT+IDFT

DFT:

本质是把多项式的系数表达转化为点值表达。因为点值表达,y可以直接相乘。点值表达下相乘的复杂度是O(n)的。

我们分别对两个多项式求x为$\omega_n^i$时的y值。

然后可以O(n)求出乘积多项式x为$\omega_n^i$时的y值。

求法:

把F(x)奇偶分类。

$FL(x)=a_0+a_2x+...+a_{n-2}x^{n/2-1}$

$FR(x)=a_1+a_3x+...+a_{n-1}x^{n/2-1}$

$F(x)=FL(x^2)+xFR(x^2)$


带入那些神奇的单位根之后,
发现有:

$0<=k<n/2$

$F(\omega_n^k)=Fl(\omega_{n/2}^k)+\omega_{n}^kFR(\omega_{n/2}^k)$


$F(\omega_n^{k+n/2})=Fl(\omega_{n/2}^k)-\omega_{n}^kFR(\omega_{n/2}^k)$

我们只要知道Fl、FR多项式在那n/2个位置的点值,那么就可以知道F那n个位置的点值了。

分治就可以处理出来。



IDFT:

经过一系列矩阵的运算之后,,,,

可以得到:

$b_k=[(ω_n^{-k})^0y_0+(ω_n^{-k})^1y_1+(ω_n^{-k})^2y_2+...+(ω_n^k)^{n-1}y_{n-1}]/n$

可以把y当做系数,

只要知道,当x是一系列w的时候,值是多少。

那么就求出来了$b_k$

FFT再写一遍。

注意这里带入的是$ω_n^{-k}$

开始的$tmp$有所不同

 

// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define reg register int
#define il inline
#define numb (ch^'0')
using namespace std;
typedef long long ll;
il void rd(int &x){
    char ch;x=0;bool fl=false;
    while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);
    (fl==true)&&(x=-x);
}
namespace Miracle{
const int N=1e6+5;
const double Pi=acos(-1);
struct node{
    double x,y;
    node(){}
    node(double xx,double yy){x=xx,y=yy;}
    node operator +(const node &b){
        return node(x+b.x,y+b.y);
    }
    node operator -(const node &b){
        return node(x-b.x,y-b.y);
    }
    node operator *(const node &b){
        return node(x*b.x-y*b.y,x*b.y+y*b.x);
    }
}a[4*N],b[4*N];
int n,m;
int r[4*N];
void FFT(node *f,short op){
    for(reg i=0;i<n;++i){
        if(i<r[i]){
            node tmp=f[i];
            f[i]=f[r[i]];
            f[r[i]]=tmp;
        }
    }
    for(reg p=2;p<=n;p<<=1){
        int len=p/2;
        node tmp(cos(Pi/len),op*sin(Pi/len));
        for(reg k=0;k<n;k+=p){
            node buf(1,0);
            for(reg l=k;l<k+len;++l){
                node tt=buf*f[l+len];
                f[l+len]=f[l]-tt;
                f[l]=f[l]+tt;
                buf=buf*tmp;
            }
        }
    }
}
int main(){
    scanf("%d%d",&n,&m);
    for(reg i=0;i<=n;++i) scanf("%lf",&a[i].x);
    for(reg i=0;i<=m;++i) scanf("%lf",&b[i].x);
    for(m=n+m,n=1;n<=m;n<<=1);
    for(reg i=0;i<n;++i){
        r[i]=r[i>>1]>>1|((i&1)?n>>1:0);
    }
    FFT(a,1);FFT(b,1);
    for(reg i=0;i<n;++i) b[i]=a[i]*b[i];
    FFT(b,-1);
    for(reg i=0;i<=m;++i) printf("%.0lf ",fabs(b[i].x)/n);
    return 0;
}

}
int main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
   Date: 2018/11/21 8:05:13
*/
多项式乘法

 

 

关键点就是在于,用了单位根这个东西,可以避免平方、避免爆long long 以及精度损失的情况下,再利用乘法分配律,可以O(nlogn)得到多项式的点值表达。

 

例题:

P3338 [ZJOI2014]力

 

思路:要用FFT,必然要化成多项式卷积的形式

即形如:$h[j]=\sum_{i=0}^j f[i]*g[j-i]$

这样的话,我们把f,g分别作为两个多项式的系数,那么,发现,h[j]的值,恰好是f,g两个多项式乘积得到的多项式的第j+1项的系数。(考虑次数j是怎么来的)

就可以FFT优化这个n^2的算式了。

 

对于这个题:

令$f[i]=q[i]$,$g[i]=\frac{1}{i*i}$

特别的;有$g[0]=0$

则有$E[j]=\sum_{i=0}^jf[i]*g[j-i]-\sum_{i=j}^nf[i]*g[i-j]$

我们可以分开算,

后面的减法部分类似一个后缀,把$f$数组$reverse$一下,就变成了前缀了。$g$数组不用,因为距离要保持这样。

于是;

$E[j]=\sum_{i=0}^jf[i]*g[j-i]-\sum_{i=0}^{n-j}f'[i]*g[n-j-i]$

两次$FFT$即可

 

值得注意的是:

1.g数组赋值的时候,i*i可能会爆int,导致精度误差。所以,写成1/i/i比1/(i*i)要好得多。(30pts->100pts)

2.乘积多项式一定要n+n项都算出来,因为最后的插值和每一项的点值都有关系。即使我们只关心前n项。

 

代码:

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define numb (ch^'0')
using namespace std;
typedef long long ll;
il void rd(int &x){
    char ch;x=0;bool fl=false;
    while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);
    (fl==true)&&(x=-x);
}
namespace Miracle{
const int N=200000+5;
const double Pi=acos(-1);
struct node{
    double x,y;
    node(){}
    node(double xx,double yy){
        x=xx;y=yy;
    }
}f[2*N],g[2*N],h[2*N];
double q[2*N];
int r[2*N];
int n,m;
node operator+(const node &a,const node &b){
    return node(a.x+b.x,a.y+b.y);
}
node operator-(const node &a,const node &b){
    return node(a.x-b.x,a.y-b.y);
}
node operator*(const node &a,const node &b){
    return node(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
void FFT(node *f,short op){
    for(reg i=0;i<n;++i){
        if(i<r[i]){
            node tmp=f[i];
            f[i]=f[r[i]];
            f[r[i]]=tmp;
        }
    }
    for(reg p=2;p<=n;p<<=1){
        int len=p/2;
        node tmp=node(cos(Pi/len),op*sin(Pi/len));
        for(reg k=0;k<n;k+=p){
            node buf=node(1,0);
            for(reg l=k;l<k+len;++l){
                node tt=buf*f[l+len];
                f[l+len]=f[l]-tt;
                f[l]=f[l]+tt;
                buf=buf*tmp;
            }
        }
    }
}
int main(){
    scanf("%d",&m);
    for(reg i=1;i<=m;++i){
        scanf("%lf",&q[i]);
        if(i)g[i]=node((double)1/(double)i/(double)i,0);
    }
    for(n=1;n<=2*m;n<<=1);
    //cout<<" nn "<<n<<endl;
    for(reg i=0;i<n;++i){
        f[i]=node(q[i],0);
        //cout<<f[i].x<<" ";
    }
    
    //g[0]=node(0,0);
    for(reg i=0;i<n;++i){
        r[i]=(r[i>>1]>>1)|((i&1)?(n>>1):0);
    }
    
    FFT(f,1);
    FFT(g,1);
    for(reg i=0;i<n;++i) f[i]=g[i]*f[i];
    FFT(f,-1);
    
    
    
    reverse(q+1,q+n);
    for(reg i=0;i<n;++i){
        h[i]=node(q[i],0);
    }
    FFT(h,1);
    for(reg i=0;i<n;++i) h[i]=h[i]*g[i];
    FFT(h,-1);
    
    for(reg i=1;i<=m;++i){
        node ans=f[i]-h[n-i];
        printf("%lf\n",ans.x/n);
    }
    return 0;
}

}
int main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
   Date: 2018/11/21 10:17:15
*/

 

FFT优化高精乘法:

把数字看成系数,把10^k看做是x^k,那么就可以得到多项式。

这两个多项式相乘,得到的多项式,各个系数通过进位变成个位数之后,直接输出系数即可。

值得注意的是:

浮点数四舍五入赋值:

$a=floor(b+0.5);$

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define numb (ch^'0')
using namespace std;
typedef long long ll;
il void rd(int &x){
    char ch;x=0;bool fl=false;
    while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);
    (fl==true)&&(x=-x);
}
namespace Miracle{
const int N=60000+5;
const double Pi=acos(-1);
struct node{
    double x,y;
    node(){}
    node(double xx,double yy){
        x=xx;y=yy;
    }
}a[4*N],b[4*N];
char p[N],q[N];
int c[4*N];
int n,m;
int r[4*N];
node operator+(node a,node b){
    return node(a.x+b.x,a.y+b.y);
}
node operator-(node a,node b){
    return node(a.x-b.x,a.y-b.y);
}
node operator*(node a,node b){
    return node(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
void FFT(node *f,short op){
    for(reg i=0;i<n;++i){
        if(i<r[i]){
            node tmp=f[i];
            f[i]=f[r[i]];
            f[r[i]]=tmp;
        }
    }
    for(reg p=2;p<=n;p<<=1){
        int len=p/2;
        node tmp=node(cos(Pi/len),op*sin(Pi/len));
        for(reg k=0;k<n;k+=p){
            node buf=node(1,0);
            for(reg l=k;l<k+len;++l){
                node tt=buf*f[l+len];
                f[l+len]=f[l]-tt;
                f[l]=f[l]+tt;
                buf=buf*tmp;
            }
        }
    }
}
int main(){
    scanf("%d",&m);
    scanf("%s",p);scanf("%s",q);
    for(reg i=0;i<m;++i){
        a[m-i-1].x=p[i]-'0';
        b[m-i-1].x=q[i]-'0';
    }
    for(m=m+m,n=1;n<m;n<<=1);
    for(reg i=0;i<n;++i){
        r[i]=r[i>>1]>>1|((i&1)?n>>1:0);
    }
    FFT(a,1);FFT(b,1);
    
    for(reg i=0;i<n;++i) b[i]=a[i]*b[i];
    FFT(b,-1);
    for(reg i=0;i<n;++i){
        c[i]=floor(b[i].x/n+0.5);
    }
    
    int x=0;
    for(reg i=0;i<n;++i){
        c[i]+=x;
        x=(int)c[i]/10;
        c[i]%=10;
    }
    while(c[n-1]==0&&n>=1) --n;
    for(reg i=n-1;i>=0;--i){
        printf("%d",c[i]);
    }
    return 0;
}

}
int main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
   Date: 2018/11/21 16:30:14
*/
FFT高精

 

posted @ 2018-11-21 12:09  *Miracle*  阅读(475)  评论(0编辑  收藏  举报