逸一时,误一世!|

MistZero

园龄:5年5个月粉丝:8关注:3

关于 FFT 和 NTT 等基础多项式科技

最基础的 FFT 可以 O(NlogN) 解决多项式乘法。
朴素的就是 O(N2) 比较低端。

FFT 先是把一个用系数表示的多项式转化为点值表示。
然后通过这样的转换搞科技。

首先一个 n1n 项多项式 f(x)=i=0n1aixi, 默认 a0,...,n1
或者可以表示为 f(x)={a0,a1,...,an1}

然后引入一种高级的东西 - 点值表示法。
把一个多项式丢到平面直角坐标系里面,看成一个函数。
然后不同的 x 代进去 f(x) 都会得到不同的 (x,f(x))

然后规定一下 f(x){(x0,f(x0)),...,(xn1,f(xn1))} 表示,就是点值表示法。

然后如果你要暴力搞的话肯定系数转点值和点值转系数都 O(N2)
两种朴素算法分别叫做 DFT 和 IDFT,离散傅里叶变换和离散傅里叶逆变换。

首先可以直接将任何一个复数表示到复平面直角坐标系上面变成一个点。
横坐标是实部,单位为 1;纵坐标为虚部,单位为 i=1
复数运算不多赘述了。


多项式转点值可以找到一个地方突破,就是代入一组特殊的 x 使次方运算减少。
那么就可以直接钦定一个复平面坐标系上以原点为圆心划一个半径为 1 的圆。
例如 (0,i)(1,0)(0,i)(1,0) 这样一个圆。
规定一下方便的叫法,这是个单位圆。
那么所有点经过若干次方都可以变成 1

方便起见可以把它分成 n 份代表 n 项式。
(1,0) 开始逆时针编号,其实随便怎么标都可以,你开心就好。
ωn1 代表 n 次单位根。
然后设第 i 个点的复数值为 ωni

然后丢个简单的公式,就是模长相乘,极角相加。
大概就是 (a,θ1)×(b,θ2)=(ab,θ1+θ2),其中横坐标为实,纵坐标为虚。

这里就可以稍稍拓展一下知道 ωni+j=ωni×ωnj
然后由定义显然有 (ωni)j=ωnij

所以由这个公式知道 (ωn1)i=ωni
那么每一个 ω 就可以通过三角函数乘上占的比例求出来。

ωnk=coskn2π+isinkn2π

(这里下标不用 i 了是因为怕跟虚数单位 i 混淆)

那么这一大堆 ω 就可以当 x0...n1

推 FFT 需要用到两个性质,就是可以拿来搞分治的两个性质。

就是 ωnk=coskn2π+isinkn2π=cos2k2n2π+isin2k2n2π=ω2n2k

以及 ωni=ωni+n2,就是类似一个多边形对角线的性质。

因为两个点关于原点对称所以有这个定理,很显然。

那么下面就是整活的推柿子了。

f(n)=i=0n1aixi=a0+a1x+a2x2+...+an1xn1

默认 n 为奇数,方便计算。

那么将 f(n) 按照下标奇偶性分成两半

=(a0+a2x2+...+an2xn2)+(a1x+a3x3+...+an1xn1)
=(a0+a2x2+...+an2xn2)+x(a1+a3x2+...+an1xn2)

然后因为两边非常相似,所以设

f1(n)=a0+a2x+a4x2+...+an2xn22
f2(n)=a1+a3x+a5x2+...+an1xn22

很明显可以把 f1f2 代进去。

原式 =f1(n2)+n×f2(n2)

然后可以开始 DFS / 迭代这样子做,前半段设有 i<n2,那么把 x=ωni 代入原式

f(ωni)=f1((ωni)2)+ωni×f2((ωni)2)

然后因为是平方,所以 2 可以乘到上面去,所以有

原式 =f1(ωn2i)+ωni×f2(ωn2i)=f1(ωn2i)+ωni×f2(ωn2i)

后面那个成立是因为 ωni=ω2n2i

因为分治特别好玩,所以后面那段也要分治。

同理可以推推推,设 t=n2(手敲累了)。

f(ωni+t)=f1(ωn2i+n)+ωni+t×f2(ωn2i+n)

因为显然有 ωn2i+n=ωnn×ωn2i
所以直接拆开,因为 ωnn=ωn0=0,跑路。

原式 =f1(ωn2i)+ωni+t×f2(ωn2i)

做到这里发现推不动了,直到看到 ωni=ωni+n2 这个玩意我们发现可以代进去。

原式 =f1(ωn2i)ωni×f2(ωn2i)

然后因为 ωni=ω2n2i 有原式 =f1(ωti)ωni×f2(ωti)

最后发现后面那一项,两半分别为相反数。
所以我们知道 f(ωni)f(ωni+t) 只有后面那个地方不同。
换句话说知道了 f(ωni) 就算出了 f(ωni+t)

那么就可以这样迭代下去求了。

时间复杂度 O(NlogN)

#include <math.h>
#include <vector>
#include <stdio.h>
#include <algorithm>
#define rg register

namespace IO {

  const int MAX_LEN = 5e7;
  static char bufin[MAX_LEN], *p1 = bufin;
  #define gc() (*p1++)
  #define isd(ch) (ch>47&&ch<58)

  template <typename T>
  inline static void read(T &ret) {
    ret = 0; rg T f = 1; char ch = gc();
    while (!isd(ch) && ch^'-') ch = gc();
    if (ch=='-') f = -1, ch = gc();
    while (isd(ch)) ret = ret*10+(ch^48), ch = gc();
    return ;
  }

  int cnt, tp;
  static char bf[20], buf[MAX_LEN];

  template <typename T>
  inline static void print(T Num) {
    if (!Num) { buf[cnt++] = '0', buf[cnt++] = ' '; return ; }
    if (Num<0) buf[cnt++] = '-', Num = -Num;
    tp = 0; while (Num)
      bf[tp++] = Num%10^48, Num /= 10;
    while (tp) buf[cnt++] = bf[--tp];
    buf[cnt++] = ' '; return ;
  }

}

namespace Math {

  const int MAX_SIZE = 1e6 + 10;
  int Md, Range, fac[MAX_SIZE];

  template <typename T>
  inline static T qpow(T bas, T pw) {
    T mult = 1;
    while (pw) {
      if (pw&1) mult = mult * bas % Md;
      bas = bas * bas % Md, pw >>= 1;
    } return mult;
  }

  template <typename T>
  inline static T inv(T x) { return qpow(x, Md-2); }
  template <typename T>
  inline static void init() {
    fac[0] = fac[1] = 1;
    for (rg int i=2; i<=Range; ++i)
      fac[i] = 1LL * fac[i-1] * i % Md;
    return ;
  }

}

using namespace IO;
using namespace Math;

const double pi = acos(-1.0);
const int N = 3e6 + 10;

struct cplx {
  double real, im;
  cplx (double real, double im):
    real(real), im(im) {}
  cplx() {}
} x[N], y[N];

// 手写复数太丑了略过

int n, m, Log, Lim = 1;
int status[N];

inline void  FFT (cplx *x, int typ) {
  for (int i=0; i<Lim; ++i)
    if (i<status[i]) Swap(x[i], x[status[i]]);
  for (int mid=1; mid<Lim; mid<<=1) {
    cplx omega(cos(pi/mid), typ*sin(pi/mid));
    for (int rig=mid<<1, pos=0; pos<Lim; pos+=rig) {
      cplx pw(1, 0);
      for (int k=0; k<mid; ++k, pw=pw*omega) {
        cplx buf1 = x[pos+k], buf2 = pw * x[pos+k+mid];
        x[pos+k] = buf1 + buf2, x[pos+k+mid] = buf1 - buf2;
      }
    }
  } return ;
}

int main() {
  fread(IO::bufin, 1, 50000000, stdin);
  IO::read(n), IO::read(m);
  for (int i=0; i<=n; ++i) IO::read(x[i].real);
  for (int i=0; i<=m; ++i) IO::read(y[i].real);
  while (Lim<=(n+m)) ++Log, Lim<<=1;
  for (int i=0; i<Lim; ++i)
    status[i] = (status[i>>1]>>1) | ((i&1)<<(Log-1));
   FFT (x, 1),  FFT (y, 1);
  for (int i=0; i<=Lim; ++i) x[i] = x[i] * y[i];
   FFT (x, -1); for (int i=0; i<=n+m; ++i)
    IO::print((int)(x[i].real/Lim+.5));
  fwrite(IO::buf, 1, IO::cnt, stdout); return 0;
}

NTT 主要思想就是把毒瘤的 FFT 复数给换掉
换成一种可以代替复数的、并且能够解决精度问题的东西,原根。

介绍几个专有名词。

:如果 gcd(a,p)=1 并且 p>1
那么对于 nmin 满足 an1 (mod p),我们称 nap 的阶,记作 δp(a)

原根:设 pN, aZ(不会打 LATEX,轻喷)。
如果 δp(a)=ϕ(p),则称 a 为模 p 的一个原根。

原根存在的充要条件是,原根 d=2,4,xy,2xy(其中 x 为奇素数,y1)。
每一个正整数 p 都有 ϕ(ϕ(p)) 个原根,素数也一样。

NTT 到这里基本上就出来了,也就是最重要的定理:

  • p 为素数且 gp 的原根,那么 gimodp 的结果两两不同。
    其中g(1,p), i(0,p)

这玩意儿可以代替原来的复数来进行 FFT,所以有了个新名字,NTT 。
FFT 里面不是用到了单位根的几条性质吗,恰好原根也满足这几个性质。
所以原根就可以拿来代替复数。

那么我们直接将 ωi 替换为 gp1imodp 即可。
p 的取值,取 998244353 非常好,原根为 3

大概求解任意一个质数 t 的原根,只需要把 t1 分解质因数
变成 t=i=1npiki 这样的乘积形式
然后如果 1in, gt1pi1 (mod t),那么 gp 原根。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 3e6 + 10;
const int Mod = 998244353;
int n, m, Lim = 1, Log, status[N];
int a[N], b[N];

void read(int &ret) {
  ret = 0; char ch = getchar();
  while (!isdigit(ch)) ch = getchar();
  while (isdigit(ch)) {
    ret = (ret<<1) + (ret<<3) + (ch^48);
    ch = getchar();
  } return ;
}

int qpow(int bas, int pw) {
  int mul = 1;
  while (pw) {
    if (pw&1) mul = mul * bas % Mod;
    bas = bas * bas % Mod;
    pw >>= 1;
  } return mul;
}

void NTT(int *x, int typ) {
  for (int i=0; i<Lim; ++i)
    if (i<status[i]) swap(x[i], x[status[i]]);
  for (int mid=1; mid<Lim; mid<<=1) {
    int omega = qpow(typ==1? 3:332748118, (Mod-1)/(mid<<1));
    for (int pos=0; pos<Lim; pos+=(mid<<1)) {
      int pw = 1;
      for (int k=0; k<mid; ++k, pw = (pw*omega)%Mod) {
        int buf1 = x[pos+k], buf2 = pw * x[pos+k+mid] % Mod;
        x[pos+k] = (buf1 + buf2) % Mod;
        x[pos+k+mid] = ((buf1 - buf2) % Mod + Mod) % Mod;
      }
    }
  } return ;
}

signed main() {
  read(n), read(m);
  for (int i=0; i<=n; ++i) read(a[i]);
  for (int i=0; i<=m; ++i) read(b[i]);
  while (Lim<=(n+m)) ++Log, Lim<<=1;
  for (int i=0; i<Lim; ++i)
    status[i] = (status[i>>1]>>1) | ((i&1)<<(Log-1));
  NTT(a, 1), NTT(b, 1);
  for (int i=0; i<Lim; ++i) a[i] = (a[i] * b[i]) % Mod;
  NTT(a, -1); int inv = qpow(Lim, Mod-2);
  for (int i=0; i<=n+m; ++i)
    printf("%lld ", a[i] * inv % Mod);
  return 0;
}

本文作者:MistZero

本文链接:https://www.cnblogs.com/MistZero/p/Basic-FFT-and-NTT.html

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   MistZero  阅读(111)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起