Loading

FFT & NTT

FFT 快速傅里叶变换

\(O(nlogn)\) 计算多项式乘法

参考博客

系数表示法 转换为 点值表示法

\[\omega_n^k = cos(\dfrac {2\pi\cdot k} n) + i \cdot sin(\dfrac {2\pi \cdot k} n) \]

\[A(x)=a_0+a_1*x+a_2*{x^2}+a_3*{x^3}+a_4*{x^4}+a_5*{x^5}+\\ \dots+a_{n-2}*x^{n-2}+a_{n-1}*x^{n-1} \]

\[A(x)=(a_0+a_2*{x^2}+a_4*{x^4}+\dots+a_{n-2}*x^{n-2})+\\(a_1*x+a_3*{x^3}+a_5*{x^5}+ \dots+a_{n-1}*x^{n-1}) \]

\[A_1(x)=a_0+a_2*{x}+a_4*{x^2}+\dots+a_{n-2}*x^{\frac{n}{2}-1} \]

\[A_2(x)=a_1+a_3*{x}+a_5*{x^2}+ \dots+a_{n-1}*x^{\frac{n}{2}-1} \]

\[A(x)=A_1(x^2)+xA_2(x^2) \]

带入 \(x = \omega_n^k\)

\[A(\omega_n^k) = A_1(\omega_{\frac n2}^k) + \omega_n^kA_2(\omega_{\frac n2}^k) \]

带入 $x = \omega_n^{k+\frac n2} $

\[A(\omega_n^{k+\frac n2}) = A_1(\omega_{\frac n2}^k) -\omega_n^kA_2(\omega_{\frac n2}^k) \]

也就是说如果知道了 $A_1(x),A_2(x) $ 分别在 \(\omega_{\frac n2}^0\) , \(\omega_{\frac n2}^1\) , \(\omega_{\frac n2}^2\) ,...,\(\omega_{\frac n2}^{\frac n2 -1}\) 的取值,

就可以 \(O(n)\) 的求出 \(A(x)\)

void fft(cp *a,int n,int inv)//inv是取共轭复数的符号
{
    if (n==1)return;
    int mid=n/2;
    static cp b[MAXN];
    for(int i = 0;i < mid;i++)b[i]=a[i*2],b[i+mid]=a[i*2+1];
    
    for(int i = 0;i < n;i++)a[i]=b[i];
    fft(a,mid,inv),fft(a+mid,mid,inv);//分治
    
    for(int i = 0;i < mid;i++)
    {
        cp x(cos(2*pi*i/n),inv*sin(2*pi*i/n));//inv取决是否取共轭复数
        b[i]=a[i]+x*a[i+mid],b[i+mid]=a[i]-x*a[i+mid];
    }
    for(int i = 0;i < a;i++)a[i]=b[i];
}

每个位置分治后最终的位置是二进制翻转后的位置

void fft(cp *a,int n,int inv)
{
    int bit=0;
    while ((1<<bit)<n)bit++;
    fo(i,0,n-1)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
        if (i<rev[i])swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
    }
    for (int mid=1;mid<n;mid*=2)//mid是准备合并序列的长度的二分之一
    {
    	cp temp(cos(pi/mid),inv*sin(pi/mid));//单位根,pi的系数2已经约掉了
        for (int i=0;i<n;i+=mid*2)//mid*2是准备合并序列的长度,i是合并到了哪一位
		{
            cp omega(1,0);
            for (int j=0;j<mid;j++,omega*=temp)//只扫左半部分,得到右半部分的答案
            {
                cp x=a[i+j],y=omega*a[i+j+mid];
                a[i+j]=x+y,a[i+j+mid]=x-y;//这个就是蝴蝶变换什么的
            }
        }
    }
}

洛谷模板

注意 lim

#include<bits/stdc++.h>
using namespace std;

const double pi = acos(-1.0);
const int N = 3e6 + 10;

struct cp {
	double x, y;
	cp() {}
	cp(double _x, double _y) {
		x = _x; y = _y;
	}
	cp operator + (cp b) {
		return cp(x + b.x, y + b.y);
	}
	cp operator -(cp b) {
		return cp(x - b.x, y - b.y);
	}
	cp operator *(cp b) {
		return cp(x * b.x - y * b.y, x * b.y + y * b.x);
	}
};
int rev[N];
int bit = 0;
int lim;
void FFT(cp* a, int inv) {
	
	for (int i = 0; i < lim; i++) {
		if (i < rev[i]) {
			swap(a[i], a[rev[i]]);
		}
	}
	
	for (int mid = 1; mid < lim; mid <<= 1) {
		cp temp(cos(pi / mid), inv * sin(pi / mid));
		for (int i = 0; i < lim; i += mid * 2) {
			cp omega(1, 0);
			for (int j = 0; j < mid; j++, omega = omega * temp) {
				cp x = a[i + j], y = omega * a[i + j + mid];
				a[i + j] = x + y, a[i + j + mid] = x - y;
			}
		}
	}
}

int n, m;

cp A[N], B[N];

int main() {
	scanf("%d%d", &n, &m);
	
	lim = 1;
	while (lim <= n + m)lim<<=1,bit++;//调整至 2^k

	for (int i = 0; i < lim; i++) {
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
	}
	for (int i = 0; i <= n; i++)scanf("%lf", &A[i].x), A[i].y = 0;
	for (int i = 0; i <= m; i++)scanf("%lf", &B[i].x), B[i].y = 0;

	FFT(A, 1);
	FFT(B, 1);
	for (int i = 0; i <= lim; i++) {
		A[i] = A[i] * B[i];
	}
	FFT(A, -1);
	for (int i = 0; i <= n + m; i++) {
		printf("%d ", int(A[i].x /lim+0.5));
	}

}

NTT

参考博客

原根

还没有整太明白

待补,丢一个板子

#include<bits/stdc++.h>
#define swap(a,b) (a^=b,b^=a,a^=b)
using namespace std;

#define LL long long 
const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118;
char buf[1 << 21], * p1 = buf, * p2 = buf;

int N, M, limit = 1, L, r[MAXN];
LL a[MAXN], b[MAXN];
inline LL fastpow(LL a, LL k) {
	LL base = 1;
	while (k) {
		if (k & 1) base = (base * a) % P;
		a = (a * a) % P;
		k >>= 1;
	}
	return base % P;
}
inline void NTT(LL* A, int type) {
	for (int i = 0; i < limit; i++)
		if (i < r[i]) swap(A[i], A[r[i]]);
	for (int mid = 1; mid < limit; mid <<= 1) {
		LL Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
		for (int j = 0; j < limit; j += (mid << 1)) {
			LL w = 1;
			for (int k = 0; k < mid; k++, w = (w * Wn) % P) {
				int x = A[j + k], y = w * A[j + k + mid] % P;
				A[j + k] = (x + y) % P,
					A[j + k + mid] = (x - y + P) % P;
			}
		}
	}
}
int main() {
	scanf("%d%d", &N, &M);
	for (int i = 0; i <= N; i++) scanf("%d", a + i);
	for (int i = 0; i <= M; i++) scanf("%d", b + i);

	while (limit <= N + M) limit <<= 1, L++;
	for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
	NTT(a, 1); NTT(b, 1);
	for (int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
	NTT(a, -1);
	LL inv = fastpow(limit,	 P - 2);
	for (int i = 0; i <= N + M; i++)
		printf("%d ", (a[i] * inv) % P);
	return 0;
}
posted @ 2020-10-06 21:30  —O0oO-  阅读(176)  评论(0编辑  收藏  举报