FFT & FNT 简要整理

几周前搞了搞……有点时间简要整理一下,诸多不足之处还请指出。
 
有哪些需要理解的地方?

  • 点值表示:对于多项式 \(A(x)\),把 \(n\) 个不同的 \(x\) 代入,会得出 \(n\) 个不同的 \(y\),在坐标系内就是 \(n\) 个不同的点,那么这 \(n\) 个点唯一确定该多项式
  • 为什么引入单位根 \(\omega\) 作为变量 \(x\):若代入一些 \(x\) ,使每个 \(x\) 的若干次方等于 \(1\),就不用做全部的次方运算了
  • 单位根的性质:于是可以分治实现 \(FFT\),复杂度降至 \(O(n log n)\)
  • 共轭复数:复数 \(z=a+bi\) 的共轭复数为 \(a−bi\)(虚部取反)。一个多项式在分治的过程中乘上单位根的共轭复数,分治完的每一项除以 \(n\) 即为原多项式的每一项系数
  • 初始化序列反转二进制数 : \(r[i] = (r[i >> 1] >> 1) | ((i \& 1) << (l - 1))\)
  • 利用二进制序列优化 \(FFT\):每个位置分治后的最终位置为其二进制翻转后得到的位置

 
算法的步骤?

  • 输入,将向量长度转化成二的幂次
  • 初始化序列反转二进制数 : \(r[i] = (r[i >> 1] >> 1) | ((i \& 1) << (l - 1))\)
  • 利用 \(FFT\) 将多项式转化为点值表示,如下:
  • 利用处理出的二进制求出待处理序列
  • 分治,每次找序列的一半,利用单位根计算(乘上 \(k\) 次幂),后对称到另一半
  • 将点值表示相乘,利用 \(IFFT\) 将其转换回系数表达式,输出
  • 多项式在分治的过程中乘上单位根的共轭复数,分治完的每一项除 \(n\) 即为原多项式的每一项系数
     
    \(NTT\) 其实就是在模意义下做了些许改变,十分类似。
  • \(NTT\) 模数:原根为 \(3\),有 \(469762049\)\(998244353\)\(1004535809\)
  • 原根在此特殊意义下代替了单位根,除法相应地变成逆元的乘法运算
  • 在模数并不是 \(NTT\) 模数时,采用 \(MTT\) 或三模数 \(NTT\) + \(CRT\)

 
\(FFT\)\(FNT\) 板子?

// Fast_fourier_transform 
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

const double pi = acos(-1.0);
const int maxn = 5000000 + 10;
int n, m, len, limit, rado, r[maxn];

struct Complex {
  double x, y;
  Complex(double tx = 0, double ty = 0) { x = tx, y = ty; }
} f_1[maxn], f_2[maxn];

inline Complex operator + (const Complex &a, const Complex &b) { return Complex(a.x + b.x, a.y + b.y); }
inline Complex operator - (const Complex &a, const Complex &b) { return Complex(a.x - b.x, a.y - b.y); }
inline Complex operator * (const Complex &a, const Complex &b) { return Complex(a.x * b.x - a.y * b.y, a.x * b.y + b.x * a.y); }

inline int read() {
  register char ch = 0; register int w = 0, x = 0;
  while( !isdigit(ch) ) w |= (ch == '-'), ch = getchar();
  while( isdigit(ch) ) x = (x * 10) + (ch ^ 48), ch = getchar();
  return w ? -x : x;
}

inline void Fast_fourior_transform(Complex *a, int type) {
  for(int i = 0; i < limit; ++i) if( i < r[i] ) swap(a[i], a[r[i]]);
  for(int mid = 1; mid < limit; mid = mid << 1) {
    Complex Base_w = Complex(cos(pi / mid), type * sin(pi / mid));
    for(int len = mid << 1, l = 0; l < limit; l = l + len) {
      Complex w = Complex(1, 0);
      for(int k = 0; k < mid; ++k, w = w * Base_w) {
        Complex tmp_1 = a[l + k], tmp_2 = w * a[l + mid + k];
        a[l + k] = tmp_1 + tmp_2, a[l + mid + k] = tmp_1 - tmp_2;
      }
    }
  }
}

int main(int argc, char const *argv[])
{
  freopen("..\\nanjolno.in", "r", stdin);
  freopen("..\\nanjolno.out", "w", stdout);

  scanf("%d%d", &n, &m);
  for(int i = 0; i <= n; ++i) f_1[i].x = read();
  for(int i = 0; i <= m; ++i) f_2[i].x = read();
  len = n + m, limit = 1, rado = 0;
  while( limit <= len ) limit = limit << 1, ++rado;
  for(int i = 0; i < limit; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (rado - 1));
  Fast_fourior_transform(f_1, 1);
  Fast_fourior_transform(f_2, 1);
  for(int i = 0; i < limit; ++i) f_1[i] = f_1[i] * f_2[i];
  Fast_fourior_transform(f_1, -1);
  for(int i = 0; i <= len; ++i) printf("%d\n", (int)(f_1[i].x / limit + 0.1));

  fclose(stdin), fclose(stdout);
  return 0;
}
// Fast_number_theory_transform 
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

const int mod = 998244353;
const int maxn = 5000000 + 10;
int n, m, len, limit, rado, f_1[maxn], f_2[maxn], r[maxn];

inline int read() {
  register char ch = 0; register int w = 0, x = 0;
  while( !isdigit(ch) ) w |= (ch == '-'), ch = getchar();
  while( isdigit(ch) ) x = (x * 10) + (ch ^ 48), ch = getchar();
  return w ? -x : x;
}

inline int Fast_pow(int a, int p) {
  long long x = a, ans = 1ll;
  for( ; p; x = x * x % mod, p = p >> 1) if( p & 1 ) ans = x * ans % mod;
  return (int)ans;
}

inline void Fast_numbertheory_transform(int *a, int limit, int type) {
  int rado = bit[limit];
  for(int i = 0; i < limit; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (rado - 1));
  for(int i = 0; i < limit; ++i) if( i < r[i] ) swap(a[i], a[r[i]]);
  for(int mid = 1; mid < limit; mid = mid << 1) {
    int Base_p = Fast_pow(3ll, (mod - 1) / (mid << 1));
    if( type == -1 ) Base_p = Fast_pow(Base_p, mod - 2);
    for(int l = 0, length = mid << 1; l < limit; l = l + length) {
      for(int k = 0, p = 1; k < mid; ++k, p = 1ll * p * Base_p % mod) {
        int x = a[l + k], y = 1ll * p * a[l + mid + k] % mod;
        a[l + k] = (x + y) % mod, a[l + mid + k] = (x - y + mod) % mod;
      }
    }
  }
  if( type == -1 ) for(int i = 0; i < limit; ++i) a[i] = 1ll * a[i] * Fast_pow(limit, mod - 2) % mod;
}

int main(int argc, char const *argv[])
{
  freopen("..\\nanjolno.in", "r", stdin);
  freopen("..\\nanjolno.out", "w", stdout);

  scanf("%d%d", &n, &m);
  for(int i = 0; i <= n; ++i) f_1[i] = read();
  for(int i = 0; i <= m; ++i) f_2[i] = read();
  len = n + m, limit = 1, rado = 0;
  while( limit <= len ) limit = limit << 1, ++rado;
  for(int i = 0; i < limit; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (rado - 1));
  Fast_numbertheory_transform(f_1, 1);
  Fast_numbertheory_transform(f_2, 1);
  for(int i = 0; i < limit; ++i) f_1[i] = 1ll * f_1[i] * f_2[i] % mod;
  Fast_numbertheory_transform(f_1, -1);
  for(int i = 0; i <= len; ++i) printf("%d ", f_1[i]);

  fclose(stdin), fclose(stdout);
  return 0;
}

 
                  流萤断续光,一明一灭一尺间,寂寞何以堪。

posted @ 2019-02-09 20:53  南條雪绘  阅读(544)  评论(0编辑  收藏  举报