FFT(快速傅里叶变换)

题目链接

3122. 多项式乘法P3803 【模板】多项式乘法(FFT)

3122. 多项式乘法

题目描述

给定一个 \(n\) 次多项式 \(F(x)=a_0+a_1x+a_2x_2+…+a_nx_n\)

以及一个 \(m\) 次多项式 \(G(x)=b_0+b_1x+b_2x_2+…+b_mx_m\)

已知 \(H(x)=F(x)⋅G(x)=c_0+c_1x+c_2x_2+…+c_{n+m}x_{n+m}\)

请你计算并输出 \(c_0,c_1,…,c_{n+m}\)

输入格式

第一行包含两个整数 \(n,m\)

第二行包含 \(n+1\) 个整数 \(a_0,a_1,…,a_n\)

第三行包含 \(m+1\) 个整数 \(b_0,b_1,…,b_m\)

输出格式

共一行,依次输出 \(c_0,c_1,…,c_{n+m}\)

数据范围

\(1≤n,m≤10^5\),
\(0≤a_i,b_i≤9\)

输入样例:

1 2
1 3
2 2 1

输出样例:

2 8 7 3

解题思路

fft

一个 \(n-1\) 次多项式 \(f(x)=a_0+a_1\times x^1+a_2\times x^2+\dots + a_{n-1}\times x^{n-1}\)可以由 \(n\)\((x_i,f(x_i))\) 点唯一表示

证明:
即求解 \(n\) 元一次方程的 \(a_i\),将其系数用行列式表示出来:

\[\left|\begin{array}{cccc} 1 & x_{1} & \cdots & x_{1}^{n-1} \\ 1 & x_{2} & \cdots & x_{2}^{n-1}\\ \vdots & \vdots & & \vdots \\ 1 & x_{n} & \cdots & x_{n}^{n-1} \end{array}\right| \]

即该系数行列式即为范德蒙行列式的转置,故其值为:\(\prod_{1 \leq j<i \leq n}\left(x_{i}-x_{j}\right)\),而 \(x_i\neq x_j\),故其值不为 \(0\),而一个一次 \(n\) 元方程组有唯一解则其系数行列式不为 \(0\),故其解唯一,所以一个 \(n-1\) 次多项式可以由 \(n\) 个点唯一表示

一般 \(fft\) 都是用来求解多项式的卷积(乘积),暴力做法的复杂度为 \(O(nm)\),而 \(fft\) 利用上述性质,即将其转化为点表示法,利用点表示法求出结果的点表示法,最后再将点表示法转化为系数表示法,例如:\(f(x)=g(x)\times h(x)\),将 \(g(x_i)\)\(h(x_i)\) 转化为点表示法,\(f(x)\)\(n+m\) 项,\(g(x)\)\(h(x)\) 不足的补零,将对应的共同的 \(x_i\)\(g(x_i)\)\(h(x_i)\) 相乘,得 \(f(x)\) 的点表示法,最后再将点表示法转化为系数表示法,而点表示法和系数表示法的相互转化的复杂度为 \(O(nlogn)\),即 \(fft\) 的复杂度为 \(O((n+m)\times log(n+m))\)

所以 \(fft\) 的关键在于点表示法和系数表示法的相互转化
系数表示法转化为点表示法
其中点当然也可以是复数,故这 \(n\) 个点可以为单位圆上辐角均分的 \(n\) 个点,用 \(w_n^k\) 表示为第 \(k\) 个点,其表示的复数为 \((cos(2\pi\times \frac{k}{n}),sin(2\pi\times \frac{k}{n}))\),另外要用到的复数乘法的一些性质:乘法后的复数模长为两复数模长相乘,辐角为两复数辐角相加,则有 \(w_n^{k+\frac{n}{2}}=-w_n^k\)\(w_{2n}^{2k}=w_n^k\)
设一个 \(n-1\)\(n\)\(2\) 次幂)次多项式 \(A(x)=a_{0}+a_{1} x+a_{2} x^{2} +\cdots+a_{n-2} x^{n-2}+a_{n-1} x^{n-1}\)
设:

\[\begin{aligned} &A_{1}(x)=a_{0}+a_{2} x^{2}+\cdots+a_{n-2} x^{\frac{n}{2}-1} \\ &A_{2}(x)=a_{1}+a_{3} x^{1}+\cdots+a_{n-1} x^{\frac{n}{2}-1} \end{aligned} \]

则有 \(A(x)=A_1(x^2)+xA_2(x^2)\)
假设 \(0 \leq k<\frac{n}{2}\) ,将 \(x=\omega_{n}^{k}\) 代入,得

\[A\left(\omega_{n}^{k}\right)=A_{1}\left(\omega_{n}^{2 k}\right)+\omega_{n}^{k} A_{2}\left(\omega_{n}^{2 k}\right) \]

\(x=\omega_{n}^{k+\frac{n}{2}}\) 代入,得

\[\begin{aligned} A\left(\omega_{n}^{k+\frac{n}{2}}\right) &=A_{1}\left(\omega_{n}^{2 k+n}\right)+\omega_{n}^{k+\frac{n}{2}} A_{2}\left(\omega_{n}^{2 k+n}\right) \\ &=A_{1}\left(\omega_{n}^{2 k}\right)+\omega_{n}^{k+\frac{n}{2}} A_{2}\left(\omega_{n}^{2 k}\right) \\ &=A_{1}\left(\omega_{n}^{2 k}\right)-\omega_{n}^{k} A_{2}\left(\omega_{n}^{2 k}\right) \end{aligned} \]

即如果知道一半的 \(A(\omega_{n}^{k})\),则可以知道另一半的 \(A(\omega_{n}^{k+\frac{n}{2}})\),另外每次都能将原问题的求解规模缩小一半,故其时间复杂度为 \(O(nlogn)\)
由于 \(fft\) 的递归写法常数巨大,故一般采用迭代写法,例如:

\[\underline{a_0\ a_1\ a_2\ a_3\ a_4\ a_5\ a_6\ a_7}\\ \underline{a_0\ a_2\ a_4\ a_6}\ \underline{\ a_1\ a_3\ a_5\ a_7}\\ \underline{a_0\ a_4}\ \underline{\ a_2\ a_6}\ \underline{\ a_1\ a_5}\ \underline{\ a_3\ a_7}\\ \underline{a_0}\ \underline{a_4}\ \underline{a_2}\ \underline{a_6}\ \underline{a_1}\ \underline{a_5}\ \underline{a_3}\ \underline{a_7}\\ \]

其中最上层的 \(a_i\) 表示按多项式顺序的值,最下层表示初始时的 \(a_i\),例如 \(a_0\) 由下层的 \(a_0\)\(a_1\) 计算得到,其他同理。迭代需要最底层的数,\(\color{red}{最底层的数的下标和原数的下标之间有什么关系?}\)下标的二进制之间互为翻转关系(蝴蝶变换)。简略证明:如果一个数为奇数,即二进制下最低位为 \(1\),则其一定在另外一半中出现,而另外一半最高位都为 \(1\),即最高位和最低位满足翻转关系,同理不考虑最高位和最低位的情况下其他位也满足翻转关系。得到最底层,往上计算即得点表示法,具体计算:设置一个变量 \(mid\),表示当前区间数需要计算的次数,由于每次只需计算一半,\(2\times mid\) 即当前区间数,\(i\) 为每个区间开始的位置,\(j\) 为每个区间的位置变量,\(i+j\) 即为每个区间中多项式对应的值,据此递推计算即可

点表示法转化为系数表示法

设一个 \(n-1\)\(n\)\(2\) 次幂)次多项式 \(A(x)=a_{0}+a_{1} x+a_{2} x^{2} +\cdots+a_{n-2} x^{n-2}+a_{n-1} x^{n-1}\),由迭代过程得到 \(n\) 个多项式的值 \(y_i\),则 \(a_k=\frac{\sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i}{n}\)

证明:\(\sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(\omega_n^i)^j(\omega_n^{-k})^i=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\)
\(s(x)=1+x+x^2+\dots +x^{n-1}\),则如果 \(k\neq 0\)\(s(\omega_n^k)=1+\omega_n^k +\omega_n^{2k}+\dots +\omega_n^{(n-1)k}\),而 \(\omega_n^ks(\omega_n^k)=\omega_n^k+\omega_n^{2k}+\dots +\omega_n^n=s(\omega_n^k)\),而 \(\omega_n^k\neq 0\),则 \(s(\omega_n^k)=0\),如果 \(k=0\),则 \(s(\omega_n^k)=n\)。故 \(\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\) 当且仅当 \(j=k\)\(\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\neq 0\),且其等于 \(n\),故 \(\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i=na_k\),得证
\(\color{red}{知道这个有什么用处呢?}\)\(g(x)=\sum_{i=0}^{n-1}y_ix^i\),则 \(a_k=\frac{g(\omega_n^{-k})}{n}\),即可以利用转化为点表示法的过程求解系数表示法

  • 时间复杂度:\(O((n+m)\times log(n+m))\)

代码

// Problem: 多项式乘法
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/3125/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// %%%Skyqwq
#include <bits/stdc++.h>
 
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
 
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
 
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
 
template <typename T> void inline read(T &x) {
    int f = 1; x = 0; char s = getchar();
    while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}

const double pi=acos(-1);
const int N=3e5+5;
int n,m,bit,tot,rev[N];
struct cp
{
	double x,y;
	cp operator+(const cp &o)const
	{
		return {x+o.x,y+o.y};
	}
	cp operator-(const cp &o)const
	{
		return {x-o.x,y-o.y};
	}
	cp operator*(const cp &o)const
	{
		return {x*o.x-y*o.y,x*o.y+y*o.x};
	}
}a[N],b[N];
void fft(cp a[],int inv)
{
	for(int i=0;i<tot;i++)
		if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int mid=1;mid<tot;mid<<=1)
	{
		cp w1={cos(pi/mid),inv*sin(pi/mid)};
		for(int i=0;i<tot;i+=mid*2)
		{
			cp wk={1,0};
			for(int j=0;j<mid;j++,wk=wk*w1)
			{
				cp x=a[i+j],y=a[i+j+mid];
				a[i+j]=x+wk*y,a[i+j+mid]=x-wk*y;
			}
		}
	}
}
int main()
{
    help;
    cin>>n>>m;
    for(int i=0;i<=n;i++)cin>>a[i].x;
    for(int i=0;i<=m;i++)cin>>b[i].x;
    while((1<<bit)<n+m+1)bit++;
    tot=1<<bit;
    for(int i=0;i<tot;i++)rev[i]=rev[i>>1]>>1|((i&1)<<(bit-1));
    fft(a,1),fft(b,1);
    for(int i=0;i<tot;i++)a[i]=a[i]*b[i];
    fft(a,-1);
    for(int i=0;i<=n+m;i++)cout<<int(a[i].x/tot+0.5)<<' ';
    return 0;
}
posted @ 2021-10-25 15:16  zyy2001  阅读(432)  评论(0编辑  收藏  举报