FFT

 这里是板子,虽然看懂了原理,但是代码还是好难理解哦
void fft(int n,complex<double>*buffer,int offset,int step,complex<double>* epsilon)     
{                                                                                       
if(n==1) return;                                                                       
int m=n>>1;                                                                            
fft(m,buffer,offset,step<<1,epsilon);                                                  
fft(m,buffer,offset+step,step<<1,epsilon);                                             
for(int k=0;k!=m;++k)                                                                  
{                                                                                      
int pos=2*step*k;                                                                     
temp[k]=buffer[pos+offset]+epsilon[k*step]*buffer[pos+offset+step];                   
temp[k+m]=buffer[pos+offset]-epsilon[k*step]*buffer[pos+offset+step];                 
}                                                                                      
for(int i=0;i!=n;++i)                                                                  
buffer[i*step+offset]=temp[i];                                                        
}                                                                                       
void init_epsilon(int n)                                                                
{                                                                                       
double pi=acos(-1);                                                                    
for(int i=0;i!=n;++i)                                                                  
{                                                                                      
epsilon[i]=complex<double>(cos(2.0*pi*i/n),sin(2.0*pi*i/n));                          
arti_epsilon[i]=conj(epsilon[i]);                                                     
}                                                                                      
}                                                                                       
int reverse_add(int x)                                                                  
{                                                                                       
for(int l=1<<bit_length;(x^=l)<l;l>>=1);                                               
return x;                                                                              
}                                                                                       
/* 这时候 n 已经补齐到 2 的幂次 */                                                      
void bit_reverse(int n, complex_t *x)                                                   
{                                                                                       
for(int i=0,j=0;i!=n;++i)                                                              
{                                                                                      
if(i>j) swap(x[i],x[j]);                                                              
for(int l=n>>1;(j^=l)<l;l>>=1);                                                       
}                                                                                      
}                                                                                       
void transform(int n,complex_t *x,complex_t *w)                                         
{                                                                                       
bit_reverse(n, x);                                                                     
for(int i=2;i<=n;i<<=1)                                                                
{                                                                                      
int m=i>>1;                                                                           
for(int j=0;j<n;j+=i)                                                                 
{                                                                                     
for(int k=0;k!=m;++k)                                                                
{                                                                                    
complex_t z=x[j+m+k]*w[n/i*k];                                                      
x[j+m+k]=x[j+k]-z;                                                                  
x[j+k]+=z;                                                                          
}                                                                                    
}                                                                                     
}                
}
 这个是带注释的版本
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
using namespace std;
#define N 301000
const double pi=acos(-1);
struct node
{
    double x,y;
    node(){x=y=0;}
    node(double x,double y):x(x),y(y){}
}a[N],b[N];
node operator + (node x,node y) {return node(x.x+y.x,x.y+y.y);}
node operator - (node x,node y) {return node(x.x-y.x,x.y-y.y);}
node operator * (node x,node y) {return node(x.x*y.x-x.y*y.y,x.x*y.y+x.y*y.x);}
void fft(node *s,int n,int t)
{
    if (n==1) return;
    node a0[n>>1],a1[n>>1];
    for (int i=0;i<=n;i+=2) 
     a0[i>>1]=s[i],a1[i>>1]=s[i+1];
    fft(a0,n>>1,t);fft(a1,n>>1,t);
    node wn(cos(2*pi/n),t*sin(2*pi/n)),w(1,0);
    for (int i=0;i<(n>>1);i++,w=w*wn) 
     s[i]=a0[i]+w*a1[i],s[i+(n>>1)]=a0[i]-w*a1[i];
}
int main()
{
    int n,m,fn,i;
    scanf("%d%d",&n,&m);
    for (i=0;i<=n;i++) scanf("%lf",&a[i].x);
    for (i=0;i<=m;i++) scanf("%lf",&b[i].x);
    fn=1;while (fn<=n+m) fn<<=1;
    fft(a,fn,1);fft(b,fn,1);
    for (i=0;i<=fn;i++) a[i]=a[i]*b[i];
    fft(a,fn,-1);
    for (i=0;i<=n+m;i++) printf("%d ",(int)(a[i].x/fn+0.5));
    printf("\n");
    return 0;
}
posted @ 2017-04-18 18:34  OcahIBye  阅读(195)  评论(0编辑  收藏  举报