SCL--FFT & NTT
2015-07-28 02:26:59
总结:
(1)关于FFT,算法导论讲得非常好,证明详细,课后题很有启发性。这里先上花了一天多时间综合总结起来的模板,效率还可。
大致流程有倍次,求值,乘法,插值。求值点是 n 次单位复数根。
(2)关于NTT,推荐一下 Acdreamer 的 blog ,有关原根的知识:blog
NTT 就是 FFT 在整数域模运算下的形式。求值点由 n 次单位复数根变成了模数 P 的原根 g 的幂:$g^{(P-1)/t}$,其中 $t = 2^k$ 且 $t \mid (P-1)$,它在模 P 意义下恰好是一个 t 次单位根。
模数的取法:1004535809 = (479 << 21) + 1,或 998244353 = 2^23 * 7 * 17 + 1,两者的原根均为 3。
以UOJ的多项式乘法为测试平台
NTT:
// Polynomial multiplication modulo P using the number-theoretic transform (NTT).
// Input : degrees n and m, then n + 1 coefficients of the first polynomial and
//         m + 1 coefficients of the second one (lowest degree first).
// Output: the n + m + 1 coefficients of the product, each reduced into [0, P).
#include <algorithm>
#include <cstdio>
#include <cstring>

const int P = (479 << 21) + 1;    // 1004535809 = 479 * 2^21 + 1, an NTT-friendly prime (not a Fermat prime)
const int G = 3;                  // primitive root modulo P
const int MAX_LEN = 1 << 18;      // largest supported transform length
const int MAXN = MAX_LEN + 10;
const int NUM = 20;               // number of precomputed root levels (must exceed log2(MAX_LEN))

int rev[MAXN];                    // bit-reversal permutation for the current length N
int A1[MAXN],A2[MAXN],wn[2][NUM]; // wn[0][i]: primitive 2^i-th root of unity mod P, wn[1][i]: its inverse
int n,m,n3,N,bit;                 // n, m: coefficient counts; n3: result length; N = 2^bit >= n3

// Returns x^y mod `mod` by binary exponentiation (y >= 0 expected).
int Q_pow(int x,int y,int mod){
    int res = 1;
    x %= mod;
    while(y){
        if(y & 1) res = (int)(1LL * res * x % mod);
        x = (int)(1LL * x * x % mod);
        y >>= 1;
    }
    return res;
}

// Computes the result length n3, the transform length N (smallest power of two
// that is >= n3), the bit-reversal table and the table of roots of unity.
void Pre_cal(){
    n3 = n + m - 1;                               // number of coefficients of the product
    std::memset(rev,0,sizeof(rev));
    N = 1;
    bit = 0;
    while(N < n3){
        N <<= 1;
        ++bit;
    }
    // The loop body never runs when N == 1, so (bit - 1) is never negative here.
    for(int i = 1; i < N; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    for(int i = 0; i < NUM; ++i){
        int t = 1 << i;
        wn[0][i] = Q_pow(G,(P - 1) / t,P);        // evaluation point: 2^i-th root of unity
        wn[1][i] = Q_pow(wn[0][i],P - 2,P);       // its modular inverse (Fermat's little theorem)
    }
}

// In-place iterative NTT of A[0..n-1].
//   n : transform length; must be the power of two N that Pre_cal() prepared rev[] for.
//   f : 1 for the forward transform (evaluation), -1 for the inverse (interpolation).
// Elements of A are expected to be non-negative.
void NTT(int *A,int n,int f){
    for(int i = 0; i < n; ++i)
        if(i < rev[i]) std::swap(A[i],A[rev[i]]);

    const int id = (f == -1) ? 1 : 0;             // which root table to use
    for(int len = 2,level = 1; len <= n; len <<= 1,++level){
        const int half = len >> 1;
        const int step = wn[id][level];           // primitive len-th root of unity (or its inverse)
        for(int k = 0; k < n; k += len){          // each block of size len
            int w = 1;
            for(int j = k; j < k + half; ++j){    // butterfly over the two halves
                const int t = (int)(1LL * w * A[j + half] % P);
                const int u = A[j] % P;           // keeps unreduced first-level input in range
                int sum = u + t;
                if(sum >= P) sum -= P;
                int diff = u - t;
                if(diff < 0) diff += P;
                A[j] = sum;
                A[j + half] = diff;
                w = (int)(1LL * w * step % P);
            }
        }
    }

    if(f == -1){
        const int inv = Q_pow(n,P - 2,P);         // divide by n to finish the inverse transform
        for(int i = 0; i < n; ++i) A[i] = (int)(1LL * A[i] * inv % P);
    }
}

// Pipeline: pad to length N, evaluate both polynomials, multiply pointwise,
// interpolate, then print the n3 result coefficients separated by spaces.
void Solve(){
    Pre_cal();
    NTT(A1,N,1);
    NTT(A2,N,1);
    for(int i = 0; i < N; ++i) A1[i] = (int)(1LL * A1[i] * A2[i] % P); // 64-bit product
    NTT(A1,N,-1);
    for(int i = 0; i < n3 - 1; ++i) printf("%d ",A1[i]);
    printf("%d\n",A1[n3 - 1]);
}

// Reads one coefficient and reduces it into [0, P). Returns false on input failure.
static bool Read_coef(int *dst){
    int x;
    if(scanf("%d",&x) != 1) return false;
    x %= P;
    if(x < 0) x += P;
    *dst = x;
    return true;
}

int main(){
    if(scanf("%d%d",&n,&m) != 2){
        fprintf(stderr,"failed to read the polynomial degrees\n");
        return 1;
    }
    // Reject sizes whose product would not fit the static buffers.
    if(n < 0 || m < 0 || n >= MAX_LEN || m >= MAX_LEN || n + m + 1 > MAX_LEN){
        fprintf(stderr,"polynomial degrees out of range\n");
        return 1;
    }
    n++,m++;                                      // degrees -> coefficient counts
    for(int i = 0; i < n; ++i)
        if(!Read_coef(A1 + i)){
            fprintf(stderr,"failed to read a coefficient\n");
            return 1;
        }
    for(int i = 0; i < m; ++i)
        if(!Read_coef(A2 + i)){
            fprintf(stderr,"failed to read a coefficient\n");
            return 1;
        }
    Solve();
    return 0;
}
FFT:
// Polynomial multiplication using a floating-point FFT.
// Input : degrees n and m, then n + 1 coefficients of the first polynomial and
//         m + 1 coefficients of the second one (lowest degree first).
// Output: the n + m + 1 coefficients of the product, rounded to the nearest integer.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>

const int MAX_LEN = 1 << 18;      // largest supported transform length
const int MAXN = MAX_LEN + 10;
const double DPI = 2.0 * std::acos(-1.0); // 2 * pi

int n,m,n3,N,bit;                 // n, m: coefficient counts; n3: result length; N = 2^bit >= n3
int rev[MAXN];                    // bit-reversal permutation for the current length N

// Minimal complex number: a is the real part, b the imaginary part.
struct CP{
    double a,b;
    CP(double ta = 0,double tb = 0) : a(ta) , b(tb) {}
}A1[MAXN],A2[MAXN];

// Operands are taken by const reference so temporaries can be used as well.
inline CP operator * (const CP &a,const CP &b){return CP(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
inline CP operator + (const CP &a,const CP &b){return CP(a.a + b.a,a.b + b.b);}
inline CP operator - (const CP &a,const CP &b){return CP(a.a - b.a,a.b - b.b);}

// Computes the result length n3, the transform length N (smallest power of two
// that is >= n3) and the bit-reversal table.
void Pre_cal(){
    n3 = n + m - 1;                               // number of coefficients of the product
    std::memset(rev,0,sizeof(rev));
    N = 1;
    bit = 0;
    while(N < n3){
        N <<= 1;
        ++bit;
    }
    // The loop body never runs when N == 1, so (bit - 1) is never negative here.
    for(int i = 1; i < N; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

// In-place iterative FFT of A[0..n-1].
//   n : transform length; must be the power of two N that Pre_cal() prepared rev[] for.
//   f : 1 for the forward transform (evaluation), -1 for the inverse (interpolation).
// The inverse divides both the real and the imaginary part by n.
void FFT(CP *A,int n,int f){
    for(int i = 0; i < n; ++i)
        if(i < rev[i]) std::swap(A[i],A[rev[i]]);

    for(int len = 2; len <= n; len <<= 1){        // size of the sub-transforms being merged
        const int half = len >> 1;
        const CP wm(std::cos(DPI / len),f * std::sin(DPI / len)); // primitive len-th root of unity
        for(int k = 0; k < n; k += len){          // each block of size len
            CP w(1,0);
            for(int j = k; j < k + half; ++j){    // butterfly over the two halves
                const CP t = w * A[j + half];
                const CP u = A[j];
                A[j] = u + t;
                A[j + half] = u - t;
                w = w * wm;
            }
        }
    }

    if(f == -1)
        for(int i = 0; i < n; ++i){
            A[i].a /= n;
            A[i].b /= n;
        }
}

// Rounds to the nearest integer; floor(x + 0.5) stays correct for negative values,
// unlike a plain (int)(x + 0.5) cast, which truncates toward zero.
static int Round_coef(double x){
    return (int)std::floor(x + 0.5);
}

// Pipeline: pad to length N, evaluate both polynomials, multiply pointwise,
// interpolate, then print the n3 result coefficients separated by spaces.
void Solve(){
    Pre_cal();
    FFT(A1,N,1);
    FFT(A2,N,1);
    for(int i = 0; i < N; ++i) A1[i] = A1[i] * A2[i];
    FFT(A1,N,-1);
    for(int i = 0; i < n3 - 1; ++i) printf("%d ",Round_coef(A1[i].a));
    printf("%d\n",Round_coef(A1[n3 - 1].a));
}

int main(){
    if(scanf("%d%d",&n,&m) != 2){
        fprintf(stderr,"failed to read the polynomial degrees\n");
        return 1;
    }
    // Reject sizes whose product would not fit the static buffers.
    if(n < 0 || m < 0 || n >= MAX_LEN || m >= MAX_LEN || n + m + 1 > MAX_LEN){
        fprintf(stderr,"polynomial degrees out of range\n");
        return 1;
    }
    n++;                                          // degrees -> coefficient counts
    m++;
    for(int i = 0; i < n; ++i){
        if(scanf("%lf",&A1[i].a) != 1){
            fprintf(stderr,"failed to read a coefficient\n");
            return 1;
        }
        A1[i].b = 0;
    }
    for(int i = 0; i < m; ++i){
        if(scanf("%lf",&A2[i].a) != 1){
            fprintf(stderr,"failed to read a coefficient\n");
            return 1;
        }
        A2[i].b = 0;
    }
    Solve();
    return 0;
}