FFT(快速傅里叶变换)
题目链接
3122. 多项式乘法同P3803 【模板】多项式乘法(FFT)
3122. 多项式乘法
题目描述
给定一个 \(n\) 次多项式 \(F(x)=a_0+a_1x+a_2x_2+…+a_nx_n\)。
以及一个 \(m\) 次多项式 \(G(x)=b_0+b_1x+b_2x_2+…+b_mx_m\)。
已知 \(H(x)=F(x)⋅G(x)=c_0+c_1x+c_2x_2+…+c_{n+m}x_{n+m}\)。
请你计算并输出 \(c_0,c_1,…,c_{n+m}\)。
输入格式
第一行包含两个整数 \(n,m\)。
第二行包含 \(n+1\) 个整数 \(a_0,a_1,…,a_n\)。
第三行包含 \(m+1\) 个整数 \(b_0,b_1,…,b_m\)。
输出格式
共一行,依次输出 \(c_0,c_1,…,c_{n+m}\)。
数据范围
\(1≤n,m≤10^5\),
\(0≤a_i,b_i≤9\)
输入样例:
1 2
1 3
2 2 1
输出样例:
2 8 7 3
解题思路
fft
一个 \(n-1\) 次多项式 \(f(x)=a_0+a_1\times x^1+a_2\times x^2+\dots + a_{n-1}\times x^{n-1}\)可以由 \(n\) 个 \((x_i,f(x_i))\) 点唯一表示
证明:
即求解 \(n\) 元一次方程的 \(a_i\),将其系数用行列式表示出来:
即该系数行列式即为范德蒙行列式的转置,故其值为:\(\prod_{1 \leq j<i \leq n}\left(x_{i}-x_{j}\right)\),而 \(x_i\neq x_j\),故其值不为 \(0\),而一个一次 \(n\) 元方程组有唯一解则其系数行列式不为 \(0\),故其解唯一,所以一个 \(n-1\) 次多项式可以由 \(n\) 个点唯一表示
一般 \(fft\) 都是用来求解多项式的卷积(乘积),暴力做法的复杂度为 \(O(nm)\),而 \(fft\) 利用上述性质,即将其转化为点表示法,利用点表示法求出结果的点表示法,最后再将点表示法转化为系数表示法,例如:\(f(x)=g(x)\times h(x)\),将 \(g(x_i)\) 和 \(h(x_i)\) 转化为点表示法,\(f(x)\) 共 \(n+m\) 项,\(g(x)\) 和 \(h(x)\) 不足的补零,将对应的共同的 \(x_i\) 的 \(g(x_i)\) 和 \(h(x_i)\) 相乘,得 \(f(x)\) 的点表示法,最后再将点表示法转化为系数表示法,而点表示法和系数表示法的相互转化的复杂度为 \(O(nlogn)\),即 \(fft\) 的复杂度为 \(O((n+m)\times log(n+m))\)
所以 \(fft\) 的关键在于点表示法和系数表示法的相互转化
系数表示法转化为点表示法
其中点当然也可以是复数,故这 \(n\) 个点可以为单位圆上辐角均分的 \(n\) 个点,用 \(w_n^k\) 表示为第 \(k\) 个点,其表示的复数为 \((cos(2\pi\times \frac{k}{n}),sin(2\pi\times \frac{k}{n}))\),另外要用到的复数乘法的一些性质:乘法后的复数模长为两复数模长相乘,辐角为两复数辐角相加,则有 \(w_n^{k+\frac{n}{2}}=-w_n^k\),\(w_{2n}^{2k}=w_n^k\)
设一个 \(n-1\) (\(n\) 为 \(2\) 次幂)次多项式 \(A(x)=a_{0}+a_{1} x+a_{2} x^{2} +\cdots+a_{n-2} x^{n-2}+a_{n-1} x^{n-1}\)
设:
则有 \(A(x)=A_1(x^2)+xA_2(x^2)\)
假设 \(0 \leq k<\frac{n}{2}\) ,将 \(x=\omega_{n}^{k}\) 代入,得
将 \(x=\omega_{n}^{k+\frac{n}{2}}\) 代入,得
即如果知道一半的 \(A(\omega_{n}^{k})\),则可以知道另一半的 \(A(\omega_{n}^{k+\frac{n}{2}})\),另外每次都能将原问题的求解规模缩小一半,故其时间复杂度为 \(O(nlogn)\)
由于 \(fft\) 的递归写法常数巨大,故一般采用迭代写法,例如:
其中最上层的 \(a_i\) 表示按多项式顺序的值,最下层表示初始时的 \(a_i\),例如 \(a_0\) 由下层的 \(a_0\) 和 \(a_1\) 计算得到,其他同理。迭代需要最底层的数,\(\color{red}{最底层的数的下标和原数的下标之间有什么关系?}\)下标的二进制之间互为翻转关系(蝴蝶变换)。简略证明:如果一个数为奇数,即二进制下最低位为 \(1\),则其一定在另外一半中出现,而另外一半最高位都为 \(1\),即最高位和最低位满足翻转关系,同理不考虑最高位和最低位的情况下其他位也满足翻转关系。得到最底层,往上计算即得点表示法,具体计算:设置一个变量 \(mid\),表示当前区间数需要计算的次数,由于每次只需计算一半,\(2\times mid\) 即当前区间数,\(i\) 为每个区间开始的位置,\(j\) 为每个区间的位置变量,\(i+j\) 即为每个区间中多项式对应的值,据此递推计算即可
点表示法转化为系数表示法
设一个 \(n-1\) (\(n\) 为 \(2\) 次幂)次多项式 \(A(x)=a_{0}+a_{1} x+a_{2} x^{2} +\cdots+a_{n-2} x^{n-2}+a_{n-1} x^{n-1}\),由迭代过程得到 \(n\) 个多项式的值 \(y_i\),则 \(a_k=\frac{\sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i}{n}\)
证明:\(\sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j(\omega_n^i)^j(\omega_n^{-k})^i=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\)
设 \(s(x)=1+x+x^2+\dots +x^{n-1}\),则如果 \(k\neq 0\),\(s(\omega_n^k)=1+\omega_n^k +\omega_n^{2k}+\dots +\omega_n^{(n-1)k}\),而 \(\omega_n^ks(\omega_n^k)=\omega_n^k+\omega_n^{2k}+\dots +\omega_n^n=s(\omega_n^k)\),而 \(\omega_n^k\neq 0\),则 \(s(\omega_n^k)=0\),如果 \(k=0\),则 \(s(\omega_n^k)=n\)。故 \(\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\) 当且仅当 \(j=k\) 时 \(\sum_{i=0}^{n-1}(\omega_n^{j-k})^i\neq 0\),且其等于 \(n\),故 \(\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i=na_k\),得证
\(\color{red}{知道这个有什么用处呢?}\) 设 \(g(x)=\sum_{i=0}^{n-1}y_ix^i\),则 \(a_k=\frac{g(\omega_n^{-k})}{n}\),即可以利用转化为点表示法的过程求解系数表示法
- 时间复杂度:\(O((n+m)\times log(n+m))\)
代码
// Problem: 多项式乘法
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/3125/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
// %%%Skyqwq
#include <bits/stdc++.h>
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
template <typename T> void inline read(T &x) {
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
x *= f;
}
const double pi=acos(-1);
const int N=3e5+5;
int n,m,bit,tot,rev[N];
struct cp
{
double x,y;
cp operator+(const cp &o)const
{
return {x+o.x,y+o.y};
}
cp operator-(const cp &o)const
{
return {x-o.x,y-o.y};
}
cp operator*(const cp &o)const
{
return {x*o.x-y*o.y,x*o.y+y*o.x};
}
}a[N],b[N];
void fft(cp a[],int inv)
{
for(int i=0;i<tot;i++)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(int mid=1;mid<tot;mid<<=1)
{
cp w1={cos(pi/mid),inv*sin(pi/mid)};
for(int i=0;i<tot;i+=mid*2)
{
cp wk={1,0};
for(int j=0;j<mid;j++,wk=wk*w1)
{
cp x=a[i+j],y=a[i+j+mid];
a[i+j]=x+wk*y,a[i+j+mid]=x-wk*y;
}
}
}
}
int main()
{
help;
cin>>n>>m;
for(int i=0;i<=n;i++)cin>>a[i].x;
for(int i=0;i<=m;i++)cin>>b[i].x;
while((1<<bit)<n+m+1)bit++;
tot=1<<bit;
for(int i=0;i<tot;i++)rev[i]=rev[i>>1]>>1|((i&1)<<(bit-1));
fft(a,1),fft(b,1);
for(int i=0;i<tot;i++)a[i]=a[i]*b[i];
fft(a,-1);
for(int i=0;i<=n+m;i++)cout<<int(a[i].x/tot+0.5)<<' ';
return 0;
}