总结:FFT&NTT

FFT&NTT的小总结

前言

  • 最近正在学 \(FFT\)然后一脸迷惑
  • 我看我是完全不懂……

什么是 \(FFT\&NTT\)

在讲 \(FFT\&NTT\) 之前,我觉得我有必要先介绍一下 \(DFT\&IDFT\)

\(DFT\&IDFT\)

  • \(DFT\:(Discrete\:Fourier\:Transform)\) :离散傅里叶变换。而 \(IDFT\) 自然就是离散傅里叶逆变换了。
  • 离散傅里叶变换是傅里叶变换时域频域上都呈离散的形式,将信号的时域采样变换为其DTFT的频域采样。在形式上,变换两端(时域和频域上)的序列是有限长的,而实际上这两组序列都应当被认为是离散周期信号的主值序列。——百度百科
  • 当然,作为一名 \(oier\) ,我们并不需要知道离散傅里叶变换有何重要的物理意义(除非你想转物理),我们只需要知道,这个变换能够帮助我们将一个多项式以另外一种形式表达。
  • 让我们先考虑这样一个问题:假如有两个多项式,我们要去计算它们的乘积,只能去枚举一个多项式的每一个点,再和另外一个多项式的每一个点相乘,这样的时间复杂度是 \(O(n^2)\) 的。
  • 太慢了……有没有更好的算法?
  • 自然是有的(不然我问这个问题干嘛?),我们都知道,\(n\) 个点能够唯一确定一个 \(n-1\) 次的多项式,那么我们对于多项式就有了一种新的表示方法:点值表示法。然后我们假装惊奇地发现:在点值表示法下,两个多项式相乘的时间复杂度居然是 \(O(n)\) 的,只需要枚举每个横坐标,将他们的纵坐标相乘就好了。那么我们就有了一个 \(O(n)\) 的算法。像这样朴素将系数转换成点值算法就叫 \(DFT\),点值转系数叫 \(IDFT\)
  • 等等,明明已经能够在线性时间内算出结果,为什么还说是朴素算法?因为事实上,还有一个问题需要考虑,就是如何将一个多项式的系数表示变成点值表示。显而易见,随便带入 \(n\) 个点就可以了,但是这样转换的时间复杂度是 \(O(n^2)\) 的,果然很朴素
  • 如果我们能够快速地解决表示的转换,那么问题就迎刃而解了。这就需要用到 \(FFT\&NTT\)

\(FFT\&NTT\)

  • 如果你十分耐心地看到了这里并且以为我将会非常详细地去讲解 \(FFT\&NTT\) ,那你就错了,哈哈哈哈不好笑。网上讲 \(FFT\) 的比较多,本蒟蒻就不再去插一脚了,就只是稍微地讲一讲什么是 \(FFT\&NTT\) (免得像我一样学了半天了还不知道这玩意儿是干嘛的),以及个人认为 \(FFT\) 中比较重要的一些东西。
  • 这里推荐oi-wiki的官方教程和个人认为讲的比较好的几篇博客:十分简明易懂的FFTFFT详解
  • 于心不忍,我还是稍微讲讲吧。
  • \(FFT\:(Fast\:Fourier\:Transform)\) :快速傅里叶变换,是对离散傅里叶变换的优化。能够做到在 \(O(nlogn)\) 的时间复杂度内将一个多项式从点值表示法变成系数表示法。其算法原理是在复数平面内的单位圆上选取 \(2^n\) 个复数作为点的横坐标,代入多项式求得其纵坐标,(当时没学复数的时候一脸迷惑:这玩意儿还能带进去算?)。当然,这 \(2^n\) 个点可不是随便选的,这就要用到单位复数根的性质了:
    1. 周期性:\(\omega_N^{n+N}=\omega_N^n\)
    2. 对称性:\(\omega_N^{n+\frac{N}{2}}=-\omega_N^n\)
    3. 降次性:\(\omega_{kN}^{kn}=\omega_N^n\)
  • \(NTT\:(Number\:Theoretic\:Transform)\) :快速数论变换,是对 \(FFT\) 的一次改良。\(FFT\) 好是好,但是也是有一些缺点的,比如手写complex大量的 double 运算既会损失精度(世界上卡你精度的方法有千千万万种),又自带大常数。所以这个时候我们就只能用 \(NTT\) 了。
  • 仔细思考一下,\(FFT\) 的本质是什么?是什么让 \(FFT\) 能够做到 \(O(nlogn)\) 的复杂度?其实就是我之前提到的关于单位根的那些性质。那有没有什么其他的东西也拥有单位根的这些性质呢?答案是有的,原根就具有和单位根一样的性质,建议自己证明一遍。
  • 我们可以仿照单位复数根的形式,也将原根的取值看成一个圆,不过这个圆只有有限个点,每个点表达的是模数的剩余系中的值,可以证明其具有单位根的三个性质:
    1. 周期性:利用欧拉定理可得 \(r^k\) 会以 \(\varphi(p)\) 为一个循环节。
    2. 对称性:将 \(r_n^\frac{n}{2}\) 平方得 \(r_n^n\equiv r_n^0\equiv 1\:(mod\:p)\) ,由原根的性质得 \(r_n^\frac{n}{2}\not\equiv 1\:(mod\:p)\),所以 \(r_n^\frac{n}{2}\equiv -1\:(mod\:p)\)
    3. 降次性:显而易见
  • 当然,\(NTT\) 也是有自己的缺点的:比如不能够处理小数的情况,以及不能够处理没有模数的情况。不过,如果能够保证答案在模数的范围内倒是可以。对于模数的选取也有一定的要求,首先是必须要有原根,其次是必须要是 \(2\) 的较高幂次的倍数。

进入正题

终于讲到个人认为比较重要的一些东西了。。。

一:位逆序置换

  • 为什么要进行位逆序置换?因为直接递归太慢了,既耽误时间,又浪费空间。
  • 通过瞎猜,我们发现 \(FFT\) 的一个性质:系数表示法中的每个系数分治之后的最终位置为其当前位置二进制反转后的位置。那么我们就可以提前将系数的位置 swap 一下,然后一层一层地向上推进就可以了,代码很短,也就几行。
while(p<=sum) {p<<=1,bit++;}

for(int i=1;i<p;i++) {
	rev[i]=((rev[i>>1]>>1)|((i&1)<<(bit-1)));
}

for(int i=0;i<p;i++) {
	if(i<rev[i]) swap(t[i],t[rev[i]]);
}
  • 注意第一行为什么要取等于号,因为下标为了保证是2的整数次幂,所以是从 \(0\sim p-1\)

二:快速傅里叶逆变换

IDFT(傅里叶反变换)的作用,是把目标多项式的点值形式转换成系数形式。——oi-wiki

  • 这同样是整个 \(FFT\&NTT\) 算法中至关重要的一步,因为我们在点值表示法中将两个多项式相乘之后还是要将其变为系数表示法的。

  • 具体的做法是将单位根取倒数再除以 \(n\) 之后重复做一遍 \(FFT\&NTT\)

  • 使用构造法证明:设多项式 \(A(x)=\sum_0^{n-1}a_ix^i\) 经过 \(DFT\) 之后得到序列 \((b_0,b_1,b_2,\dots,b_{n-1})\) 。再分别将其作为多项式 \(B(x)\) 的系数,即 \(B(x)=\sum_0^{n-1}b_ix^i\)。这时我们用单位根的倒数对 \(B(x)\) 做一次 \(DFT\) 得到序列 \((c_0,c_1,c_2,\dots,c_{n-1})\) 。那么显然可得:

    \[\begin{aligned} c_k &=\sum_{i=0}^{n-1}b_i\left(\omega_n^{-k}\right)^i\\ &=\sum_{i=0}^{n-1}\left(\sum_{j=0}^{n-1}a_j(\omega_n^i)^j\right)\left(\omega_n^{-k}\right)^i\\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{(j-k)\times i}\\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}\left(\omega_n^{j-k}\right)^i \end{aligned} \]

    此时我们设 \(S(\omega_n^a)=\sum_{i=0}^{n-1}\left(\omega_n^{j-k}\right)^i\)

    \(a=0\) 时,\(\omega_n^0=1\:\:\:\therefore S(\omega_n^a)=n\)

    \(a\neq 0\) 时,很容易想到等比数列的求和方法:错位相减法。所以

    \[\begin{aligned} S(\omega_n^a) &=\sum_{i=0}^{n-1}\left(\omega_n^a\right)^i\\ \omega_n^aS(\omega_n^a) &=\sum_{i=1}^n\left(\omega_n^a\right)^i\\ \therefore S(\omega_n^a) &=\frac{\left(\omega_n^a\right)^n-\left(\omega_n^a\right)^0}{\omega_n^a-1}=0 \end{aligned} \]

    \[S(\omega_n^a)=\left\{\begin{aligned}n &,a=0\\0 &,a\neq 0\end{aligned}\right. \]

    那么

    \[c_k=\sum_{J=0}^{n-1}a_jS(\omega_n^{j-k})=a_k\cdot n\\ \therefore a_k=\frac{c_k}{n} \]

证毕。

挂个代码:

\(FFT:\)

#include <iostream>
#include <cstdio>
#include <cmath>

using namespace std;

const int maxn=4e6+10;
const double pi=acos(-1);

int p=1,bit;
int rev[maxn];

struct complex {
	double a,b;
	complex operator + (const complex &t) const {
		return (complex){a+t.a,b+t.b};
	}
	complex operator - (const complex &t) const {
		return (complex){a-t.a,b-t.b};
	}
	complex operator * (const complex &t) const {
		return (complex){a*t.a-b*t.b,a*t.b+b*t.a};
	}
} a[maxn],b[maxn];

inline void fft(complex *t,int inv) {
	for(int i=0;i<p;i++) {
		if(i<rev[i]) swap(t[i],t[rev[i]]);
	}
	for(int mid=1;mid<p;mid<<=1) {
		complex unit={cos(pi/mid),inv*sin(pi/mid)};
		for(int l=0;l<p;l+=mid<<1) {
			complex now={1,0};
			for(int i=0;i<mid;i++) {
				complex x=t[l+i],y=now*t[l+mid+i];
				t[l+i]=x+y;
				t[l+mid+i]=x-y;
				now=now*unit;
			}
		}
	}
}

int main() {
	int n=read(),m=read();
	for(int i=0;i<=n;i++) a[i].a=read();
	for(int i=0;i<=m;i++) b[i].a=read();
	int sum=n+m;
	while(p<=sum) {p<<=1,bit++;}			//下标从0~p-1,保证为2的整数次幂
	for(int i=1;i<p;i++) {
		rev[i]=((rev[i>>1]>>1)|((i&1)<<(bit-1)));
	}
	fft(a,1),fft(b,1);
	for(int i=0;i<p;i++) a[i]=a[i]*b[i];
	fft(a,-1);
	for(int i=0;i<=sum;i++) printf("%d ",(int)(a[i].a/p+0.5));
	putchar('\n');
	return 0;
}

// by pycr

\(NTT:\)

#include <iostream>
#include <cstdio>

using namespace std;

const int maxn=4e6+10,mod=998244353,g=3,gn=332748118;//3在998244353下的逆元

int p=1,bit,inver;
int a[maxn],b[maxn],rev[maxn];

inline int power(long long a,int x) {
	long long ans=1;
	while(x) {
		if(x&1) ans=(ans*a)%mod;
		a=(a*a)%mod;
		x>>=1;
	}
	return ans;
}

inline void ntt(int *t,int inv) {
	for(int i=0;i<p;i++) {
		if(i<rev[i]) swap(t[i],t[rev[i]]);
	}
	for(int mid=1;mid<p;mid<<=1) {
		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
		int d=mid<<1;
		for(int l=0;l<p;l+=d) {
			int now=1;
			for(int i=0;i<mid;i++) {
				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
				t[l+i]=(x+y)%mod;
				t[l+mid+i]=(x-y+mod)%mod;
				now=(long long)now*unit%mod;
			}
		}
	}
	if(inv==-1) for(int i=0;i<p;i++) {
		a[i]=(long long)a[i]*inver%mod;
	}
}

signed main() {
	int n=read(),m=read();
	for(int i=0;i<=n;i++) a[i]=read();
	for(int i=0;i<=m;i++) b[i]=read();
	int sum=n+m;
	while(p<=sum) {p<<=1,bit++;}
	inver=power(p,mod-2);
	for(int i=1;i<p;i++) {
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	}
	ntt(a,1),ntt(b,1);
	for(int i=0;i<p;i++) a[i]=(long long)a[i]*b[i]%mod;
	ntt(a,-1);
	for(int i=0;i<=sum;i++) printf("%d ",a[i]);
	return 0;
}

// by pycr

一些值得注意的点

这些都是我用无数次惨痛的教训和实验换来的

  • \(FFT\&NTT\) 在运算时取的点数一定要大于等于两个多项式相乘之后的项数,不然是无法得到最终的多项式的。用通俗的话来讲就是把 \(FFT\&NTT\) 乘爆了。
  • \(FFT\&NTT\) 是满足在模意义下的多项式乘法的性质的,即在 \(x^n\) 即以上的项的不会影响最后的结果,虽然位逆序置换会交换系数,但是 \(IDFT\) 之后会发现其实没有影响。
  • 尽管是这样,但是还是仍要保证 \(x^n\) 之后的系数尽量为 \(0\) ,因为第一点,如果后面有系数的话就有可能把 \(FFT\&NTT\) 乘爆了……在多项式的总结上会细讲。

——2021年2月7日

posted @ 2021-02-07 19:56  pycr  阅读(326)  评论(0编辑  收藏  举报