FWT 学习笔记
FWT用来干什么
快速计算位运算卷积(其中 \(\circ\) 为 or、and、xor 之一):
\[c[k] = \sum_{i \circ j = k} a[i] \cdot b[j]
\]
-
记号:\(a + b\) 表示 \(a,b\) 逐位相加 \((a[i] + b[i])\)
-
记号:\(a * b\) 表示 \(a\) 与 \(b\) 做上述(or / and / xor)卷积
-
这种卷积具有乘法分配律 \((a + b) * c = a * c + b * c\)
按下标最高位为 \(0\) 或 \(1\),把多项式(数组)\(a, b, c\) 各分成前后两半:\(a_0, b_0, c_0\) 是最高位为 \(0\) 的前一半,\(a_1, b_1, c_1\) 是最高位为 \(1\) 的后一半
这样在子问题中就可以不考虑最高位
or
\[c_0 = a_0 * b_0
\]
因为只要 \(i, j\) 中有一个最高位为 1,\(i\ or\ j\) 的最高位就是 1,贡献就统计到 \(c_1\) 上了,所以 \(c_0\) 只来自 \(a_0\) 与 \(b_0\)
\[c_1 = a_0 * b_1 + a_1 * b_0 + a_1 * b_1
\]
\[= (a_0 + a_1) * (b_0 + b_1) - a_0 *b_0
\]
\[= (a_0 + a_1) * (b_0 + b_1) - c_0
\]
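这样规模为 \(2^n\) 的问题就变成两个规模为 \(2^{n-1}\) 的子问题(分别求 \(a_0 * b_0\) 和 \((a_0 + a_1) * (b_0 + b_1)\)),总复杂度 \(O(n 2^n)\)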
边界:\(c[0] = a[0] * b[0]\)
and
\[c_1 = a_1 * b_1
\]
因为只要 \(i, j\) 中有一个最高位不为 1,\(i\ and\ j\) 的最高位就是 0,贡献就统计到 \(c_0\) 上了
\[c_0 = a_0 * b_1 + a_1 * b_0 + a_0 * b_0
\]
\[= (a_0 + a_1) * (b_0 + b_1) - c_1
\]
xor
\[c_0 = a_0 * b_0 + a_1 * b_1
\]
\[c_1 = a_1 * b_0 + a_0 * b_1
\]
令 \(x_0 = (a_0 + a_1) * (b_0 + b_1), x_1 = (a_0 - a_1) * (b_0 - b_1)\)
展开可得 \(x_0 = c_0 + c_1, x_1 = c_0 - c_1\),所以 \(c_0 = \frac {x_0 + x_1} 2, c_1 = \frac {x_0 - x_1} 2\)
然后就没了。
\(code\)
#include <bits/stdc++.h>
using namespace std;
// Short alias for getchar(); main() also uses it, so it must stay defined.
#define gc getchar
// Reads the next integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in the skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit.
// `ch` is an int (not char) so that EOF stays distinguishable from a real character
// and so that isdigit() never receives a negative non-EOF value (undefined behaviour).
// The old `register` specifier is omitted: it is ill-formed since C++17.
inline int read(){
    int ch = gc();
    int x = 0;
    bool negative = false;
    while(ch != EOF && !std::isdigit(ch)) negative |= (ch == '-'), ch = gc();
    while(ch != EOF && std::isdigit(ch)) x = x * 10 + (ch - '0'), ch = gc();
    return negative ? -x : x;
}
// Capacity of every work buffer: lengths up to 2^17, i.e. n <= 17.
const int N = 1 << 17;
// Work buffers used by main(): (a, b) -> c for OR, (d, e) -> f for AND,
// (g, h) -> i for XOR. The transforms overwrite their input arrays,
// which is why each convolution gets its own copy of the inputs.
int a[N], b[N], c[N],
    d[N], e[N], f[N],
    g[N], h[N], i[N];
int n;
// Modulus, and the inverse of 2 modulo it: 2 * inv2 == mod + 1, which is 1 (mod mod).
const int mod = 998244353, inv2 = (mod + 1) >> 1;
// Brings x from [-mod, mod) into [0, mod) by adding mod once when x is negative.
// Uses a plain comparison rather than a sign-bit mask (x >> 31 & mod): right-shifting
// a negative int is implementation-defined before C++20 and assumes a 32-bit int.
inline void Mod(int &x){
    if(x < 0) x += mod;
}
// OR convolution: c[k] = sum over (p | q) == k of a[p] * b[q], modulo `mod`.
// `lim` is the array length. main() passes 1 << n, so it is expected to be a power of two.
// Entries of a and b are assumed to already lie in [0, mod).
// a and b are used as scratch space and are overwritten.
// Splitting on the top bit: c_0 = a_0 * b_0 and c_1 = (a_0 + a_1) * (b_0 + b_1) - c_0.
inline void mulor(int *a, int *b, int *c, int lim){
    const int half = lim >> 1;
    if(half == 0){
        *c = static_cast<long long>(*a) * *b % mod;
        return;
    }
    // Upper halves become a_0 + a_1 and b_0 + b_1, reduced into [0, mod).
    for(int k = 0; k < half; ++k){
        Mod(a[k + half] += a[k] - mod);
        Mod(b[k + half] += b[k] - mod);
    }
    mulor(a, b, c, half);                          // lower half: a_0 * b_0
    mulor(a + half, b + half, c + half, half);     // upper half: (a_0 + a_1) * (b_0 + b_1)
    for(int k = 0; k < half; ++k) Mod(c[k + half] -= c[k]);
}
// AND convolution: c[k] = sum over (p & q) == k of a[p] * b[q], modulo `mod`.
// `lim` is the array length. main() passes 1 << n, so it is expected to be a power of two.
// Entries of a and b are assumed to already lie in [0, mod).
// a and b are used as scratch space and are overwritten.
// Splitting on the top bit: c_1 = a_1 * b_1 and c_0 = (a_0 + a_1) * (b_0 + b_1) - c_1.
inline void muland(int *a, int *b, int *c, int lim){
    const int half = lim >> 1;
    if(half == 0){
        *c = static_cast<long long>(*a) * *b % mod;
        return;
    }
    // Lower halves become a_0 + a_1 and b_0 + b_1, reduced into [0, mod).
    for(int k = 0; k < half; ++k){
        Mod(a[k] += a[k + half] - mod);
        Mod(b[k] += b[k + half] - mod);
    }
    muland(a, b, c, half);                         // lower half: (a_0 + a_1) * (b_0 + b_1)
    muland(a + half, b + half, c + half, half);    // upper half: a_1 * b_1
    for(int k = 0; k < half; ++k) Mod(c[k] -= c[k + half]);
}
// XOR convolution: c[k] = sum over (p ^ q) == k of a[p] * b[q], modulo `mod`.
// `lim` is the array length. main() passes 1 << n, so it is expected to be a power of two.
// Entries of a and b are assumed to already lie in [0, mod).
// a and b are used as scratch space and are overwritten.
// Splitting on the top bit with x_0 = (a_0 + a_1) * (b_0 + b_1) and
// x_1 = (a_0 - a_1) * (b_0 - b_1) gives x_0 = c_0 + c_1 and x_1 = c_0 - c_1, hence
// c_0 = (x_0 + x_1) / 2 and c_1 = (x_0 - x_1) / 2 (division by 2 done via inv2).
inline void mulxor(int *a, int *b, int *c, int lim){
    const int half = lim >> 1;
    if(half == 0){
        *c = 1ll * (*a) * (*b) % mod;
        return;
    }
    for(int k = 0; k < half; ++k){
        // Butterfly (lo, hi) -> (lo + hi, lo - hi) on both inputs, each result reduced into [0, mod).
        // The old code read b[k - half] here (out of bounds); the partner element is b[k + half].
        const int alo = a[k], ahi = a[k + half];
        a[k] = alo + ahi - mod;        Mod(a[k]);
        a[k + half] = alo - ahi;       Mod(a[k + half]);
        const int blo = b[k], bhi = b[k + half];
        b[k] = blo + bhi - mod;        Mod(b[k]);
        b[k + half] = blo - bhi;       Mod(b[k + half]);
    }
    mulxor(a, b, c, half);                         // lower half: x_0
    mulxor(a + half, b + half, c + half, half);    // upper half: x_1
    for(int k = 0; k < half; ++k){
        const int x0 = c[k], x1 = c[k + half];
        c[k] = 1ll * (x0 + x1) * inv2 % mod;
        c[k + half] = 1ll * (x0 - x1 + mod) * inv2 % mod;
    }
}
// Reads n, then 2^n values of A and 2^n values of B (values presumably already in
// [0, mod) — the transforms rely on that; confirm against the input spec).
// Prints the OR, AND and XOR convolutions of A and B, one line each.
signed main(){
    n = read();
    const int lim = 1 << n;
    // Every transform destroys its inputs, so A and B are each copied into three buffers.
    for(int k = 0; k < lim; ++k){
        const int value = read();
        g[k] = value;
        d[k] = value;
        a[k] = value;
    }
    for(int k = 0; k < lim; ++k){
        const int value = read();
        h[k] = value;
        e[k] = value;
        b[k] = value;
    }
    mulor(a, b, c, lim);
    muland(d, e, f, lim);
    mulxor(g, h, i, lim);
    const auto print_row = [lim](const int *row){
        for(int k = 0; k < lim; ++k) printf("%d ", row[k]);
        puts("");
    };
    print_row(c);
    print_row(f);
    print_row(i);
    // Consume two more characters before exiting (presumably a console pause).
    gc();
    gc();
    return 0;
}