FWT快速沃尔什变换

FWT

\[C(i) = \sum_{j @ k=i}A(j)B(k) \]

其中 \(@\) 表示 and、or、xor 三种位运算之一。

DWT

设第 \(i\) 个点值是把 \(f(i,j)\) 代入 \(x^j\) 得到的。由于位运算意义下的乘法需满足 \(x^ix^j=x^{i@j}\),

所以 \(f(i,j)\) 需满足:

\[f(i,j)f(i,k)=f(i,j@k) \]

以此为依据构造 \(f(i,j)\)

对于 and,\(f(i,j)=[i\&j=i]\),即 \([i\subseteq j]\)

对于 or,\(f(i,j)=[i|j=i]\),即 \([j\subseteq i]\)

对于 xor, \(f(i,j)=(-1)^{popcount(i\&j)}\)

并且 \(f(i,j)=\prod_p f(i_p,j_p)\),其中 \(i_p\) 表示 \(i\) 在二进制下的第 \(p\) 位。

\[\begin{aligned} DWT(C)_i &= \sum_{j=0}^{n-1}f(i,j)C(j)\\ &=\sum_{j=0}^{n/2-1}f(i,j)C(j)+\sum_{j=n/2}^{n-1}f(i,j)C(j)\\ &=\sum_{j=0}^{n/2-1}f(i_{len},j_{len})f(i_{[1,len-1]},j_{[1,len-1]})C(j)+\sum_{j=n/2}^{n-1}f(i_{len},j_{len})f(i_{[1,len-1]},j_{[1,len-1]})C(j)\\ &=f(i_{len},0)\sum_{j=0}^{n/2-1}f(i_{[1,len-1]},j_{[1,len-1]})C(j)+f(i_{len},1)\sum_{j=n/2}^{n-1}f(i_{[1,len-1]},j_{[1,len-1]})C(j) \end{aligned} \]

对于 \(i\in[0, n/2-1]\) ,

\[DWT(C)_i = f(0,0)DWT(C_L)_i + f(0,1)DWT(C_R)_i\\ DWT(C)_{i+n/2}=f(1, 0)DWT(C_L)_{i}+f(1,1)DWT(C_R)_i \]

到最底层时 \(DWT(C)_0=\sum_{j=0}^{0}f(0,j)C(j)=f(0,0)C(0)\),而 \(f(0,0)\) 对于 xor、and、or 都是 1,所以递归到最底层时的点值就是系数本身。

然后就可以左右分治了。值得注意的是,DFT 是按下标奇偶分治,而 DWT 是按下标最高位把序列分成左右两半分治。

and:

\[DWT(C)_i = DWT(C_L)_i+DWT(C_R)_i\\ DWT(C)_{i+n/2}=DWT(C_R)_i \]

or:

\[DWT(C)_i = DWT(C_L)_i\\ DWT(C)_{i+n/2}=DWT(C_L)_i+DWT(C_R)_i \]

xor:

\[DWT(C)_i = DWT(C_L)_i + DWT(C_R)_i\\ DWT(C)_{i+n/2}=DWT(C_L)_{i}-DWT(C_R)_i \]

IDWT

由DWT的式子可以解得:

and:

\[DWT(C_L)_i=DWT(C)_i-DWT(C)_{i+n/2}\\ DWT(C_R)_i=DWT(C)_{i+n/2} \]

or:

\[DWT(C_L)_i = DWT(C)_i\\ DWT(C_R)_i=DWT(C)_{i+n/2}-DWT(C)_i \]

xor:

\[DWT(C_L)_i=\frac{DWT(C)_i+DWT(C)_{i+n/2}}{2}\\ DWT(C_R)_i=\frac{DWT(C)_i-DWT(C)_{i+n/2}}{2} \]

#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <vector>

// Shorthand integer types.
typedef long long LL;
typedef unsigned long long uLL;

// Competitive-programming shorthands.
#define fir first
#define sec second
// Fully parenthesized so that e.g. SZ(*p) and SZ(a) - 1 expand correctly.
#define SZ(x) ((int)(x).size())
#define MP(x, y) std::make_pair(x, y)
#define PB(x) push_back(x)
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define GO debug("GO\n")
// Loop helpers: rep = inclusive ascending [a, b], drep = inclusive descending
// [a, b] (a >= b), REP = half-open ascending [a, b). The bound b is evaluated
// once into i##end. No `register`: that specifier was removed in C++17.
#define rep(i, a, b) for (int i = (a), i##end = (b); (i) <= i##end; ++ (i))
#define drep(i, a, b) for (int i = (a), i##end = (b); (i) >= i##end; -- (i))
#define REP(i, a, b) for (int i = (a), i##end = (b); (i) < i##end; ++ (i))

// Reads one decimal integer from stdin.
// Leading non-digit characters are skipped; a '-' anywhere among them makes the
// result negative. The character right after the last digit is consumed.
// Returns 0 (or the sign-adjusted digits read so far) when end of input is hit,
// instead of looping forever.
inline int read() {
    int x = 0;
    int sign = 1;
    int c = getchar();  // keep as int so EOF is distinguishable and isdigit() is well-defined
    while (c != EOF && !isdigit(c)) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * sign;
}
// Writes the decimal representation of integer x to stdout (no separator).
// Digits are extracted without negating x, so the minimum value of a signed
// type (e.g. LLONG_MIN) is printed correctly instead of overflowing.
template<class T> inline void write(T x) {
    char digits[48];  // 48 >= 39 digits of a 128-bit integer
    int len = 0;
    if (x < 0) putchar('-');
    do {
        int d = static_cast<int>(x % 10);
        if (d < 0) d = -d;  // remainder is <= 0 when x is negative
        digits[len++] = static_cast<char>('0' + d);
        x /= 10;
    } while (x != 0);
    while (len > 0) putchar(digits[--len]);
}
// Replaces a with b when b is strictly smaller; reports whether a was changed.
template<typename T> inline bool chkmin(T &a, T b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Replaces a with b when b is strictly larger; reports whether a was changed.
template<typename T> inline bool chkmax(T &a, T b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}

using namespace std;

// Capacity of the static coefficient arrays: 2^17 = 131072 plus 10 slack slots.
constexpr int N = 131082;
// Prime modulus for all modular arithmetic in this file.
constexpr int P = 998244353;
// Inverse of 2 modulo P: P is odd, so 2 * ((P + 1) / 2) == P + 1 == 1 (mod P).
constexpr int inv2 = (P + 1) / 2;

// Returns a + b, minus P once if the sum reaches P.
// In this file it reduces sums of two residues in [0, P), and, called as
// ADD(x - y, P), maps a difference in (-P, P) into [0, P).
LL ADD(LL a, int b) {
    const LL sum = a + b;
    if (sum >= P) return sum - P;
    return sum;
}

// One butterfly on the pair (lo, hi) = (a[k], a[k + len/2]), in place.
// kind: 1 = or, 2 = and, anything else = xor. inverse selects the IDWT step.
static inline void dwt_butterfly(int &lo, int &hi, int kind, bool inverse) {
    const int x = lo, y = hi;
    if (kind == 1) {            // or:  (L, R) -> (L, L + R)
        hi = inverse ? ADD(1ll * y - x, P) : ADD(x, y);
    } else if (kind == 2) {     // and: (L, R) -> (L + R, R)
        lo = inverse ? ADD(1ll * x - y, P) : ADD(x, y);
    } else if (inverse) {       // xor inverse: ((L + R) / 2, (L - R) / 2)
        lo = ADD(x, y) * inv2 % P;
        hi = ADD(1ll * x - y, P) * inv2 % P;
    } else {                    // xor: (L, R) -> (L + R, L - R)
        lo = ADD(x, y);
        hi = ADD(1ll * x - y, P);
    }
}

// Applies every butterfly of the level whose blocks have length len.
// Only complete blocks are visited, so a[] is never indexed at or past n.
static void dwt_level(int a[], int n, int len, int kind, bool inverse) {
    const int half = len >> 1;
    for (int base = 0; base + len <= n; base += len)
        for (int i = 0; i < half; ++i)
            dwt_butterfly(a[base + i], a[base + i + half], kind, inverse);
}

// In-place transform of a[0..n-1] modulo P for bitwise convolutions.
//   t =  1 / -1 : forward / inverse transform for OR  convolution
//   t =  2 / -2 : forward / inverse transform for AND convolution
//   t =  3 / -3 : forward / inverse transform for XOR convolution
// Any other t > 0 acts like 3, and any other t <= 0 like -3.
// n must be a power of two (n >= 1). Entries are expected to lie in [0, P).
void DWT(int a[], int n, int t) {
    const bool inverse = t <= 0;
    // Decide the operation without negating t (avoids overflow for INT_MIN).
    const int kind = (t == 1 || t == -1) ? 1 : (t == 2 || t == -2) ? 2 : 3;
    if (!inverse) {
        for (int len = 2; len <= n; len <<= 1) dwt_level(a, n, len, kind, false);
    } else {
        for (int len = n; len >= 2; len >>= 1) dwt_level(a, n, len, kind, true);
    }
}

int main() 
{
#ifndef ONLINE_JUDGE
    freopen("xhc.in", "r", stdin);
    freopen("xhc.out", "w", stdout);
#endif
    int lg2 = read();
    int n = 1 << lg2;
    static int A[N], B[N], C[N];
    rep (i, 0, n - 1) A[i] = read();
    rep (i, 0, n - 1) B[i] = read();

    DWT(A, n, 1), DWT(B, n, 1);
    rep (i, 0, n - 1) C[i] = 1ll * A[i] * B[i] % P;
    DWT(C, n, -1), DWT(A, n, -1), DWT(B, n, -1);
    rep (i, 0, n - 1) write(C[i]), putchar(' ');
    putchar('\n');

    DWT(A, n, 2), DWT(B, n, 2);
    rep (i, 0, n - 1) C[i] = 1ll * A[i] * B[i] % P;
    DWT(C, n, -2), DWT(A, n, -2), DWT(B, n, -2);
    rep (i, 0, n - 1) write(C[i]), putchar(' ');
    putchar('\n');

    DWT(A, n, 3), DWT(B, n, 3);
    rep (i, 0, n - 1) C[i] = 1ll * A[i] * B[i] % P;
    DWT(C, n, -3), DWT(A, n, -3), DWT(B, n, -3);
    rep (i, 0, n - 1) write(C[i]), putchar(' ');
    putchar('\n');

    return 0;
}

posted @ 2019-07-29 10:47  茶Tea  阅读(129)  评论(0编辑  收藏  举报