多项式FFT/NTT模板(含乘法/逆元/log/exp/求导/积分/快速幂)

自己整理出来的模板

存在的问题:

1.多项式求逆常数过大(尤其是浮点数FFT)

2.log只支持f[0]=1的情况,exp只支持f[0]=0的情况

有待进一步修改和完善

FFT:

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 typedef long long ll;
  4 typedef double db;
  5 const db pi=acos(-1);
  6 const int N=4e5+10,M=1e6+10,mod=998244353;
  7 int n,m,n2,a[N];
  8 int Pow(int x,int p) {
  9     int ret=1;
 10     for(; p; p>>=1,x=(ll)x*x%mod)if(p&1)ret=(ll)ret*x%mod;
 11     return ret;
 12 }
 13 struct P {
 14     db x,y;
 15     P operator+(const P& b) {return {x+b.x,y+b.y};}
 16     P operator-(const P& b) {return {x-b.x,y-b.y};}
 17     P operator*(const P& b) {return {x*b.x-y*b.y,x*b.y+y*b.x};}
 18     P operator/(db b) {return {x/b,y/b};}
 19     P cj() {return {x,-y};}
 20 };
 21 struct F_FT {
 22     P A[N],B[N],w[N];
 23     int b[N],c[N],d[N],e[N],f[N];
 24     void FFT(P* a,int n,int f) {
 25         for(int i=1,j=n>>1,k; i<n-1; ++i,j^=k) {
 26             if(i<j)swap(a[i],a[j]);
 27             for(k=n>>1; j&k; j^=k,k>>=1);
 28         }
 29         for(int i=0; i<n; ++i)w[i]= {cos(2*pi*i/n),sin(2*pi*i/n)};
 30         for(int k=1; k<n; k<<=1)
 31             for(int i=0; i<n; i+=k<<1)
 32                 for(int j=i; j<i+k; ++j) {
 33                     P W= {w[n/2/k*(j-i)].x,~f?w[n/2/k*(j-i)].y:-w[n/2/k*(j-i)].y};
 34                     P x=a[j],y=W*a[j+k];
 35                     a[j]=x+y,a[j+k]=x-y;
 36                 }
 37         if(!~f)for(int i=0; i<n; ++i)a[i]=a[i]/n;
 38     }
 39     void mul(int* a,int* b,int* c,int n) {
 40         for(int i=0; i<n; ++i)A[i]= {a[i]>>15,a[i]&((1<<15)-1)},B[i]= {b[i]>>15,b[i]&((1<<15)-1)},A[i+n]= {0,0},B[i+n]= {0,0};
 41         n<<=1;
 42         FFT(A,n,1),FFT(B,n,1);
 43         for(int i=0; i<=n/2; ++i) {
 44             int j=(n-i)&(n-1);
 45             P a1=(A[i]+A[j].cj())* (P) {0.5,0},b1=(A[i]-A[j].cj())* (P) {0,-0.5};
 46             P a2=(B[i]+B[j].cj())* (P) {0.5,0},b2=(B[i]-B[j].cj())* (P) {0,-0.5};
 47             P a3=(A[j]+A[i].cj())* (P) {0.5,0},b3=(A[j]-A[i].cj())* (P) {0,-0.5};
 48             P a4=(B[j]+B[i].cj())* (P) {0.5,0},b4=(B[j]-B[i].cj())* (P) {0,-0.5};
 49             A[i]=a1*a2+b1*b2* (P) {0,1},B[i]=a1*b2+a2*b1;
 50             A[j]=a3*a4+b3*b4* (P) {0,1},B[j]=a3*b4+a4*b3;
 51         }
 52         FFT(A,n,-1),FFT(B,n,-1);
 53         for(int i=0; i<n; ++i)c[i]=(((ll(A[i].x+0.5)%mod)<<30)+ll(A[i].y+0.5)+(ll(B[i].x+0.5)<<15))%mod;
 54     }
 55     void inverse(int* a,int n) {
 56         for(int i=0; i<n; ++i)b[i]=0;
 57         b[0]=Pow(a[0],mod-2);
 58         for(int m=2; m<=n; m<<=1) {
 59             mul(b,b,c,m),mul(a,c,c,m);
 60             for(int i=0; i<m; ++i)b[i]=(((ll)b[i]*2-c[i])%mod+mod)%mod;
 61         }
 62         for(int i=0; i<n; ++i)a[i]=b[i];
 63     }
 64     void der(int* a,int n) {for(int i=1; i<n; ++i)a[i-1]=(ll)i*a[i]%mod; a[n-1]=0;}
 65     void itg(int* a,int n) {for(int i=n-2; i>=0; --i)a[i+1]=(ll)Pow(i+1,mod-2)*a[i]%mod; a[0]=0;}
 66     void log(int* a,int n) {
 67         for(int i=0; i<n; ++i)d[i]=a[i];
 68         inverse(d,n),der(a,n),mul(a,d,a,n),itg(a,n);
 69     }
 70     void exp(int* a,int n) {
 71         for(int i=0; i<n; ++i)e[i]=0;
 72         e[0]=1;
 73         for(int m=2; m<=n; m<<=1) {
 74             for(int i=0; i<m; ++i)f[i]=e[i];
 75             log(f,m);
 76             for(int i=0; i<m; ++i)f[i]=(a[i]-f[i]+mod)%mod;
 77             f[0]++;
 78             mul(e,f,e,m);
 79         }
 80         for(int i=0; i<n; ++i)a[i]=e[i];
 81     }
 82     void pow(int* a,int n,int p) {
 83         int j=0;
 84         for(; j<n&&!a[j]; ++j);
 85         if(j==n)return;
 86         int px=Pow(a[j],p),invx=Pow(a[j],mod-2);
 87         for(int i=j; i<n; ++i)a[i-j]=(ll)a[i]*invx%mod;
 88         for(int i=n-j; i<n; ++i)a[i]=0;
 89         log(a,n);
 90         for(int i=0; i<n; ++i)a[i]=(ll)a[i]*p%mod;
 91         exp(a,n);
 92         for(int i=n-1; i>=(ll)j*p; --i)a[i]=(ll)a[i-j*p]*px%mod;
 93         for(int i=0; i<n&&i<(ll)j*p; ++i)a[i]=0;
 94     }
 95 } fft;
 96 int main() {
 97     scanf("%d%d",&n,&m);
 98     for(int i=0; i<n; ++i)scanf("%d",&a[i]),a[i]%=mod;
 99     for(n2=1; n2<n; n2<<=1);
100     fft.pow(a,n2,m);
101     for(int i=0; i<n; ++i)printf("%d%c",a[i]," \n"[i==n-1]);
102     return 0;
103 }
View Code

NTT:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=4e5+10,M=1e6+10,mod=998244353;
 5 const int G=3;
 6 int n,m,n2,a[N];
 7 int Pow(int x,int p) {
 8     int ret=1;
 9     for(; p; p>>=1,x=(ll)x*x%mod)if(p&1)ret=(ll)ret*x%mod;
10     return ret;
11 }
12 struct F_FT {
13     int A[N],B[N],b[N],c[N],d[N],e[N],f[N];
14     void FFT(int* a,int n,int f) {
15         for(int i=1,j=n>>1,k; i<n-1; ++i,j^=k) {
16             if(i<j)swap(a[i],a[j]);
17             for(k=n>>1; j&k; j^=k,k>>=1);
18         }
19         for(int k=1; k<n; k<<=1) {
20             int gn=Pow(G,(mod-1)/(k<<1));
21             if(f==-1)gn=Pow(gn,mod-2);
22             for(int i=0; i<n; i+=k<<1) {
23                 int g=1;
24                 for(int j=i; j<i+k; ++j,g=(ll)g*gn%mod) {
25                     int x=a[j],y=(ll)g*a[j+k]%mod;
26                     a[j]=((ll)x+y)%mod,a[j+k]=((ll)x-y+mod)%mod;
27                 }
28             }
29         }
30         if(!~f) {
31             int invn=Pow(n,mod-2);
32             for(int i=0; i<n; ++i)a[i]=(ll)a[i]*invn%mod;
33         }
34     }
35     void mul(int* a,int* b,int* c,int n) {
36         for(int i=0; i<n; ++i)A[i]=a[i],B[i]=b[i],A[i+n]=B[i+n]=0;
37         n<<=1;
38         FFT(A,n,1),FFT(B,n,1);
39         for(int i=0; i<n; ++i)c[i]=(ll)A[i]*B[i]%mod;
40         FFT(c,n,-1);
41     }
42     void inverse(int* a,int n) {
43         for(int i=0; i<n; ++i)b[i]=0;
44         b[0]=Pow(a[0],mod-2);
45         for(int m=2; m<=n; m<<=1) {
46             for(int i=0; i<m; ++i)A[i]=a[i],B[i]=b[i],A[i+m]=B[i+m]=0;
47             FFT(A,m<<1,1),FFT(B,m<<1,1);
48             for(int i=0; i<(m<<1); ++i)b[i]=(((ll)B[i]*2-(ll)A[i]*B[i]%mod*B[i]%mod)%mod+mod)%mod;
49             FFT(b,m<<1,-1);
50             for(int i=m; i<(m<<1); ++i)b[i]=0;
51         }
52         for(int i=0; i<n; ++i)a[i]=b[i];
53     }
54     void der(int* a,int n) {for(int i=1; i<n; ++i)a[i-1]=(ll)i*a[i]%mod; a[n-1]=0;}
55     void itg(int* a,int n) {for(int i=n-2; i>=0; --i)a[i+1]=(ll)Pow(i+1,mod-2)*a[i]%mod; a[0]=0;}
56     void log(int* a,int n) {
57         for(int i=0; i<n; ++i)d[i]=a[i];
58         inverse(d,n),der(a,n),mul(a,d,a,n),itg(a,n);
59     }
60     void exp(int* a,int n) {
61         for(int i=0; i<n; ++i)e[i]=0;
62         e[0]=1;
63         for(int m=2; m<=n; m<<=1) {
64             for(int i=0; i<m; ++i)f[i]=e[i];
65             log(f,m);
66             for(int i=0; i<m; ++i)f[i]=(a[i]-f[i]+mod)%mod;
67             f[0]++;
68             mul(e,f,e,m);
69         }
70         for(int i=0; i<n; ++i)a[i]=e[i];
71     }
72     void pow(int* a,int n,int p) {
73         int j=0;
74         for(; j<n&&!a[j]; ++j);
75         if(j==n)return;
76         int px=Pow(a[j],p),invx=Pow(a[j],mod-2);
77         for(int i=j; i<n; ++i)a[i-j]=(ll)a[i]*invx%mod;
78         for(int i=n-j; i<n; ++i)a[i]=0;
79         log(a,n);
80         for(int i=0; i<n; ++i)a[i]=(ll)a[i]*p%mod;
81         exp(a,n);
82         for(int i=n-1; i>=(ll)j*p; --i)a[i]=(ll)a[i-j*p]*px%mod;
83         for(int i=0; i<n&&i<(ll)j*p; ++i)a[i]=0;
84     }
85 } fft;
86 int main() {
87     scanf("%d%d",&n,&m);
88     for(int i=0; i<n; ++i)scanf("%d",&a[i]),a[i]%=mod;
89     for(n2=1; n2<n; n2<<=1);
90     fft.pow(a,n2,m);
91     for(int i=0; i<n; ++i)printf("%d%c",a[i]," \n"[i==n-1]);
92     return 0;
93 }
View Code

代码源自洛谷P4238

posted @ 2019-09-10 17:22  jrltx  阅读(416)  评论(0编辑  收藏  举报