$\text{FFT}$ 的预处理优化

之前向学习了一个 $\text{FFT}$ 的优化,但是像我这么弱的人每次打 $\text{FFT}$ 板子的时候都会忘记这个东西,在这里记一下。

​我们知道普通的 $\text{FFT}$ 会用到原根 $\omega_n^0,\omega_n^1\cdots\omega_n^{n-1}$ 然后这些东西会在枚举步长的时候通过 $\omega_n = e^{\frac{2\pi}{n}}$ 和 $e^{\theta i} = \cos \theta + i\sin \theta$ 这两个公式一次一次算出来。

​然而我们知道,调用三角函数是非常慢的,每次计算的时候,即使你是手写的 $\text{complex}$ 也会非常慢,这就使得这种 $\text{FFT}$ 的常数巨大无比。

​所以我们就预处理一下每次需要用到的 $\omega$ ,把每一种步长需要用到的 $\omega$ 扔到同一个数组 $W$ 里,有每种步长的 $\omega$ 连续。而因为 $\sum_{i=0}^{n} 2^i = 2^{i + 1} - 1$ ,所以每次需要访问步长为 $s$ 的 $\omega$ 时候只要访问 $W[s]$ 就可以了,将一个指针指向他,而后面的只要把指针一步一步往后移即可。

​这是 $\text{DFT}$ 的时候用的,但是我们知道 $\text{IDFT}$ 的时候用的 $\omega$ 和 $\text{DFT}$ 的时候是不一样的。

​然而我们不需要重新处理 $\text{IDFT}$ 用的 $\omega$ ,只需要把需要 $\text{FFT}$ 的 $A$ 从 $1$ 到 $n - 1$ 的值 $\text{reverse}$ 一下就行了。原理是本来 $\text{IDFT}$ 的时候需要把 $\omega$ 翻过来,但是那个有点麻烦,于是我们就把 $A$ 给翻过来就行了。由于 $\text{FFT}$ 可以被理解为一个特殊的矩阵乘法,所以你顺着搞下来和反着搞回去最后的结果是一样的,所以它是对的。

​然后下面贴了一道水题的代码来帮助理解:

例:​求有多少个从 $1,2,\cdots,n$ 中取三个元素的排列 $(a,b,c)$ 满足 $x_a=x_b-x_c$。​由于是排列,所以 $(a,b,c)$ 与 $(c,b,a)$ 视为两组解。

#include <algorithm>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <queue>
#include <set>
#include <stack>

#define R register
#define ll long long
#define db double
#define ld long double
#define sqr(_x) (_x) * (_x)
#define Cmax(_a, _b) ((_a) < (_b) ? (_a) = (_b), 1 : 0)
#define Cmin(_a, _b) ((_a) > (_b) ? (_a) = (_b), 1 : 0)
#define Max(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define Min(_a, _b) ((_a) < (_b) ? (_a) : (_b))
#define Abs(_x) (_x < 0 ? (-(_x)) : (_x))

using namespace std;

namespace Dntcry
{
    inline int read()
    {
        R int a = 0, b = 1; R char c = getchar();
        for(; c < '0' || c > '9'; c = getchar()) (c == '-') ? b = -1 : 0;
        for(; c >= '0' && c <= '9'; c = getchar()) a = (a << 1) + (a << 3) + c - '0';
        return a * b;
    }
    inline ll lread()
    {
        R ll a = 0, b = 1; R char c = getchar();
        for(; c < '0' || c > '9'; c = getchar()) (c == '-') ? b = -1 : 0;
        for(; c >= '0' && c <= '9'; c = getchar()) a = (a << 1) + (a << 3) + c - '0';
        return a * b;
    }
    const int Maxn = 1000010, Maxl = 600010, lim = 100000;
    const ld pi = acos(-1);
    struct Complex
    {
        ld real, imag;
        Complex operator + (const Complex &b) const
        {
            return (Complex) {real + b.real, imag + b.imag};
        }
        Complex operator - (const Complex &b) const
        {
            return (Complex) {real - b.real, imag - b.imag};
        }
        Complex operator * (const Complex &b) const
        {
            return (Complex) {real * b.real - imag * b.imag, b.real * imag + real * b.imag};
        }
    }C[Maxl], A[Maxl], w[Maxl], wl;
    int n, m, x[Maxn], Cnt[Maxl], len, bit, rev[Maxl], zero;
    ll Ans[Maxn], Sum;
    void Get_Rev(R int bit)
    {
        for(R int i = 0; i < len; i++)
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);
        return ;
    }
    void FFT(R Complex *K, R ld DFT)
    {
        for(R int i = 0; i < len; i++) if(i < rev[i]) swap(K[i], K[rev[i]])
        R Complex *W;
        for(R int i = 2; i <= len; i <<= 1)
        {
            for(R int j = 0, step = i >> 1; j < len; j += i)
            {
                W = w + step;
                for(R int k = j; k < j + step; W++, k++)
                {
                    R Complex G = K[k], H = *W * K[k + step];
                    K[k] = G + H;
                    K[k + step] = G - H;
                }
            }
        }
        if(DFT == -1.0) 
            for(R int i = 0; i < len; i++) 
                K[i].real /= 1.0 * len, K[i].imag /= 1.0 * len;
        return ;
    }
    int Main()
    {
        n = read();
        for(R int i = 1; i <= n; i++)
        {
            x[i] = read(); if(!x[i]) zero++;
            x[i] += lim, m = Max(m, x[i]);
            Cnt[x[i]]++;
        } m++;
        for(bit = 0, len = 1; (1 << bit) < (m << 1); bit++) len <<= 1;
        R int tmp = len >> 1;
        w[tmp] = (Complex) {1.0, 0.0};
        wl = w[++tmp] = (Complex) {cos(2.0 * pi / len), sin(2.0 * pi / len)};
        for(tmp++; tmp < len; tmp++) w[tmp] = w[tmp - 1] * wl;
        for(R int i = (len >> 1) - 1; i; i--) w[i] = w[i << 1];
        Get_Rev(bit);
        for(R int i = 0; i < m; i++) A[i] = (Complex) {1.0 * Cnt[i], 0.0};
        FFT(A, 1.0);
        C[0] = A[0] * A[0];
        for(R int i = 1; i < len; i++) C[i] = A[len - i] * A[len - i];
        FFT(C, -1.0);
        for(R int i = 0; i < len; i++) Ans[i] = (ll)(C[i].real + 0.5);
        for(R int i = 1; i <= n; i++) Ans[x[i] << 1]--;
        for(R int i = 1; i <= n; i++) Sum += Ans[x[i] + lim];
        Sum -= 2ll * zero * (n - 1);
        printf("%lld\n", Sum);
        return 0; 
    }
}
int main()
{
    return Dntcry :: Main();
}

  

posted @ 2019-02-28 08:54  DntcryBecthlev  阅读(366)  评论(0编辑  收藏  举报