FWT 学习笔记

FWT用来干什么

快速处理

\[c[k] = \sum_{i \circ j = k} a[i] \cdot b[j], \qquad \circ \in \{\mathrm{or},\ \mathrm{and},\ \mathrm{xor}\} \]

  • 记号:\(a + b\) 表示 \(a,b\) 逐位相加 \((a[i] + b[i])\)

  • 记号:\(a * b\) 表示 \(a\) 与 \(b\) 的上述位运算卷积(or / and / xor 之一)

  • 这种卷积具有乘法分配律 \((a + b) * c = a * c + b * c\)

根据下标的最高位为 \(0\) 或 \(1\),将多项式 \(a, b, c\) 各分为两半:\(a_0, a_1, b_0, b_1, c_0, c_1\)(下标 \(0\) 表示最高位为 \(0\) 的那一半,下标 \(1\) 表示最高位为 \(1\) 的那一半)

这样就可以不考虑最高位

or

\[c_0 = a_0 * b_0 \]

因为一旦最高位为 1 权值就统计到 \(c_1\) 上了

\[c_1 = a_0 * b_1 + a_1 * b_0 + a_1 * b_1 \]

\[= (a_0 + a_1) * (b_0 + b_1) - a_0 *b_0 \]

\[= (a_0 + a_1) * (b_0 + b_1) - c_0 \]

这样问题就缩小一倍了

边界:\(c[0] = a[0] * b[0]\)

and

\[c_1 = a_1 * b_1 \]

一旦最高位不为 1 权值就统计到 \(c_0\) 上了

\[c_0 = a_0 * b_1 + a_1 * b_0 + a_0 * b_0 \]

\[= (a_0 + a_1) * (b_0 + b_1) - c_1 \]

xor

\[c_0 = a_0 * b_0 + a_1 * b_1 \]

\[c_1 = a_1 * b_0 + a_0 * b_1 \]

\(x_0 = (a_0 + a_1) * (b_0 + b_1), x_1 = (a_0 - a_1) * (b_0 - b_1)\)

这样 \(c_0 = \frac {x_0 + x_1} 2, c_1 = \frac {x_0 - x_1} 2\)

然后就没了。

\(code\)

#include <bits/stdc++.h>
using namespace std;
#define rg register
// Reads the next decimal integer from stdin.
//
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the result is negated. Digits are then accumulated in
// base 10 until the first non-digit (which is consumed) or end of input.
//
// Returns the parsed value, or 0 when end of input is reached before any
// digit was seen. No overflow checking is performed.
inline int read(){
    // The macro stays defined after this function, so later code in the
    // file may call gc() as well.
    #define gc getchar
    // getchar() returns int: keeping it as int lets EOF be told apart from
    // real characters and keeps the isdigit() argument in its valid domain.
    int ch = gc();
    int x = 0, f = 0;
    // Stop on EOF as well; otherwise a truncated input spins here forever.
    while(ch != EOF && !isdigit(ch)) f |= (ch == '-'), ch = gc();
    while(ch != EOF && isdigit(ch)) x = x * 10 + (ch - '0'), ch = gc();
    return f ? -x : x;
}
// Capacity of every work array: up to 2^17 coefficients per sequence.
const int N = 1 << 17;
// Three (operand, operand, result) triples. main() feeds the same input to
// each triple because the transforms overwrite their operands:
//   a, b -> c : OR convolution
//   d, e -> f : AND convolution
//   g, h -> i : XOR convolution
int a[N], b[N], c[N];
int d[N], e[N], f[N];
int g[N], h[N], i[N];
// Exponent read from input; the sequences have 1 << n entries.
int n;
// All arithmetic is done modulo this prime.
const int mod = 998244353;
// Modular inverse of 2: mod is odd, so 2 * ((mod + 1) / 2) == mod + 1 == 1 (mod mod).
const int inv2 = (mod + 1) / 2;
// Adds mod to x when x is negative and leaves it untouched otherwise.
// Callers in this file use it after a single addition/subtraction, so one
// correction is presumably enough to land in [0, mod) -- it does not perform
// a full reduction.
inline void Mod(int &x){
    x += (x < 0) ? mod : 0;
}
// OR convolution modulo mod: c[k] = sum of a[i] * b[j] over all i | j == k.
//
// Divide and conquer on the top index bit. With the halves named
// (a0, a1), (b0, b1), (c0, c1):
//   c0 = a0 * b0
//   c1 = (a0 + a1) * (b0 + b1) - c0
//
// a, b : operands of length lim; both are overwritten.
// c    : result of length lim.
// lim  : sequence length (expected to be a power of two).
inline void mulor(int *a, int *b, int *c, int lim){
    const int half = lim >> 1;
    if(half == 0){
        // Single coefficient: plain modular product.
        *c = 1ll * (*a) * (*b) % mod;
        return;
    }
    int *aHi = a + half;
    int *bHi = b + half;
    int *cHi = c + half;
    // Upper halves become a0 + a1 and b0 + b1 (subtract mod, then fix sign).
    for(int k = 0; k < half; ++k){
        aHi[k] += a[k] - mod;
        Mod(aHi[k]);
        bHi[k] += b[k] - mod;
        Mod(bHi[k]);
    }
    mulor(a, b, c, half);
    mulor(aHi, bHi, cHi, half);
    // cHi currently holds (a0 + a1) * (b0 + b1); remove the c0 part.
    for(int k = 0; k < half; ++k){
        cHi[k] -= c[k];
        Mod(cHi[k]);
    }
}
// AND convolution modulo mod: c[k] = sum of a[i] * b[j] over all i & j == k.
//
// Divide and conquer on the top index bit. With the halves named
// (a0, a1), (b0, b1), (c0, c1):
//   c1 = a1 * b1
//   c0 = (a0 + a1) * (b0 + b1) - c1
//
// a, b : operands of length lim; both are overwritten.
// c    : result of length lim.
// lim  : sequence length (expected to be a power of two).
inline void muland(int *a, int *b, int *c, int lim){
    const int half = lim >> 1;
    if(half == 0){
        // Single coefficient: plain modular product.
        *c = 1ll * (*a) * (*b) % mod;
        return;
    }
    int *aHi = a + half;
    int *bHi = b + half;
    int *cHi = c + half;
    // Lower halves become a0 + a1 and b0 + b1 (subtract mod, then fix sign).
    for(int k = 0; k < half; ++k){
        a[k] += aHi[k] - mod;
        Mod(a[k]);
        b[k] += bHi[k] - mod;
        Mod(b[k]);
    }
    muland(a, b, c, half);
    muland(aHi, bHi, cHi, half);
    // c (low half) currently holds (a0 + a1) * (b0 + b1); remove the c1 part.
    for(int k = 0; k < half; ++k){
        c[k] -= cHi[k];
        Mod(c[k]);
    }
}
// XOR convolution modulo mod: c[k] = sum of a[i] * b[j] over all i ^ j == k.
//
// Divide and conquer on the top index bit. With the halves named
// (a0, a1), (b0, b1), (c0, c1):
//   x0 = (a0 + a1) * (b0 + b1)
//   x1 = (a0 - a1) * (b0 - b1)
//   c0 = (x0 + x1) / 2,  c1 = (x0 - x1) / 2
// Division by two is done by multiplying with inv2.
//
// a, b : operands of length lim; both are overwritten.
// c    : result of length lim.
// lim  : sequence length (expected to be a power of two).
inline void mulxor(int *a, int *b, int *c, int lim){
    if(!(lim >>= 1)) return (void) (*c = 1ll * (*a) * (*b) % mod);
    for(int i = 0; i < lim; ++i){
        // Read both halves before writing so sum and difference use the
        // original values. The partner of index i is always i + lim, for b
        // exactly as for a.
        const int a0 = a[i], a1 = a[i + lim];
        const int b0 = b[i], b1 = b[i + lim];
        a[i] = a0 + a1;
        a[i + lim] = a0 - a1;
        b[i] = b0 + b1;
        b[i + lim] = b0 - b1;
        // Sums: subtract mod, then fix sign. Differences: only fix sign.
        Mod(a[i] -= mod);
        Mod(a[i + lim]);
        Mod(b[i] -= mod);
        Mod(b[i + lim]);
    }
    mulxor(a, b, c, lim); mulxor(a + lim, b + lim, c + lim, lim);
    for(int i = 0; i < lim; ++i){
        const int x0 = c[i], x1 = c[i + lim];
        c[i] = 1ll * (x0 + x1) * inv2 % mod;
        // + mod keeps the difference non-negative before the multiplication.
        c[i + lim] = 1ll * (x0 - x1 + mod) * inv2 % mod;
    }
}
signed main(){
    n = read();
    int lim = 1 << n;
    for(int i = 0; i < lim; ++i) a[i] = d[i] = g[i] = read();
	for(int i = 0; i < lim; ++i) b[i] = e[i] = h[i] = read();
    mulor(a, b, c, lim);
    muland(d, e, f, lim);
    mulxor(g, h, i, lim);
    for(int i = 0; i < lim; ++i) printf("%d ", c[i]); puts("");
    for(int i = 0; i < lim; ++i) printf("%d ", f[i]); puts("");
    for(int j = 0; j < lim; ++j) printf("%d ", i[j]); puts("");
    gc(), gc();
    return 0;
}
posted @ 2020-06-06 11:55  __int256  阅读(141)  评论(0)  编辑  收藏  举报