FWT学习笔记
快速沃尔什变换FWT
概述
FFT 是加速这样的卷积 :
流程是这样的:
- \(A, B \Rightarrow FFT(A), FFT(B)\) (带入单位根)
- \(FFT(A) \cdot FFT(B)\) (点乘)
- \(IFFT(FFT(A) \cdot FFT(B)) \Rightarrow A * B\)
而对于位运算的卷积, \(\oplus\) 暂指异或
通过构造相应的 \(FWT(A)\), 尝试以同样的步骤快速得到卷积
构造的流程大概是这样的:
- 构造 FWT 的通式
- 得到 FWT 分治递归式
- 通过上式得到 IFWT 的分治递归式
时刻注意 FWT 是系数转点值的作用, IFWT 相反,
然后诸多性质可以感性得到, 因为 FWT 变换没有 FFT 那样复杂, 就不需要大量公式的推导
or 卷积
构造这样的 FWT :
来验证一下:
- \((A | B)_i\) 由下标为 \(i\) 的子集的 \((A|B)_j\) 相加, 而他们由两个或起来是 \(j\) 的两个下标相乘, 下标也是 \(i\) 的子集.
- 下标是 \(i\) 的子集的 \(A_j \cdot B_k\), 或起来只能组成 \(i\) 的子集.
- 不难看出 1 和 2 中所指的两个下标互为充要条件
再形象一点?
正向逆向都走得通
(先跳到 and 卷积的 FWT 构造)
接着看如何分治求 FWT 和 IFWT:
将 \(A\) 按最高二进制位是 0/1 分成 \(A_0, A_1\), 其实就是前一半后一半
根据通式, \(FWT(A_0)\) 只能来自 \(A_0\), 也就是说 \(FWT(A)\) 的前一半就是 \(FWT(A_0)\),
对于 \(FWT(A_1)\), 一部分来自最高位不是 0 的下标, 另一部分反之, 前者即为 \(A_0\), 后者为 \(A_1\),
考虑 \(A_0\) 和 \(A_1\) 去掉最高位后依次对应, 易得 \(FWT(A)_{n / 2 + i} = FWT(A_0)_i + FWT(A_1)_i\)
即, \(FWT(A)\) 前一半为 \(FWT(A_0)\), 后一半为 \(FWT(A_0) + FWT(A_1)\)
然后考虑 IFWT:
同样 \(A_0, A_1\), 现在将一组点值转回去
显然前一半的点值直接由 \(IFWT(A_0)\) 得到,
后一半系数呢? 既然合法的点值与系数组组对应, 那么可以有它生成的点值得到, 即 \(IFWT(A_1 - A_0)\),
然而这个式子并不和谐, 考虑 "一组系数带入一个值 - 另一组系数带入这个值 = (两组系数相减)带入这个值", \(IFWT(A_1 - A_0) = IFWT(A_1) - IFWT(A_0)\)
即, \(IFWT(A)\) 前一半为 \(IFWT(A_0)\), 后一半为 \(IFWT(A_1) - IFWT(A_0)\)
and 卷积
模仿 or 来构造 FWT
同样是符合这个的
证明方法类似
接下来还是很像(想一想)
xor 卷积
很显然这个 FWT 的通式不是之前的套路
先咕着... https://blog.csdn.net/xyyxyyx/article/details/103564869
其中 \({|x|}\) 表示二进制 1 位的个数
考虑递归式 FWT:
先考虑 \(A_0\), \(FWT(A_0)\) 不用说了, 但是 \(j\) 也可以取右半边的, 那么对应位的 \(A_1\) 比自己多一位了, 分类讨论可得 \(FWT(A_1)\) 贡献正负性不变
在考虑 \(A_1\), 同理, 但是 \(A_0\) 比自己少一位, 分类讨论可得 \(FWT(A_0)\) 的正负性不变, 但是自己的 \(FWT(A_1)\) 贡献正负性变了
即, \(FWT(A) = merge(FWT(A_0) + FWT(A_1), FWT(A_0)) - FWT(A_1))\)
然后很容易得到
代码
luogu 的板子题
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
inline LL in()
{
LL x = 0, flag = 1; char ch = getchar();
while (!isdigit(ch)) { if (ch == '-') flag = -1; ch = getchar(); }
while (isdigit(ch)) x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
return x * flag;
}
typedef double LDB;
const int MAXN = (1 << 20) + 10;
int lowbit(int x) { return x & (-x); }
int n, len;
const LL MOD = 998244353, inv2 = 499122177;
void FWT_or(LL * a, int len, int sgn)
{
// FWT(A) = merge(FWT(A0) + FWT(A1), FWT(A1)) ;
for (int mid = 1; mid < len; mid <<= 1)
for (int i = 0; i < len; i += (mid << 1))
for (int j = 0; j < mid; ++ j)
{
if (sgn == -1) (a[i + j + mid] += MOD - a[i + j]) %= MOD;
else (a[i + j + mid] += a[i + j]) %= MOD;
}
}
void FWT_and(LL * a, int len, int sgn)
{
// FWT(A) = merge(FWT(A0, FWT(A0) + FWT(A1)) ;
for (int mid = 1; mid < len; mid <<= 1)
for (int i = 0; i < len; i += (mid << 1))
for (int j = 0; j < mid; ++ j)
{
if (sgn == -1) (a[i + j] += MOD - a[i + j + mid]) %= MOD;
else (a[i + j] += a[i + j + mid]) %= MOD;
}
}
void FWT_xor(LL * a, int len, int sgn)
{
// FWT(A) = merge(FWT(A0) + FWT(A1), FWT(A0) - FWT(A1)) ;
// IFWT(A) = merge((IFWT(A0) + IFWT(A1)) / 2, (IFWT(A0) - IFWT(A1)) / 2) ;
for (int mid = 1; mid < len; mid <<= 1)
for (int i = 0; i < len; i += (mid << 1))
for (int j = 0; j < mid; ++ j)
{
LL x = a[i + j], y = a[i + j + mid];
a[i + j] = (x + y) % MOD, a[i + j + mid] = (x + MOD - y) % MOD;
if (sgn == -1)
(a[i + j] *= inv2) %= MOD, (a[i + j + mid] *= inv2) %= MOD;
}
}
LL reca[MAXN], recb[MAXN];
LL a[MAXN], b[MAXN];
void solve_or()
{
for (int i = 0; i < len; ++ i) a[i] = reca[i], b[i] = recb[i];
FWT_or(a, len, 1); FWT_or(b, len, 1);
for (int i = 0; i < len; ++ i) (a[i] *= b[i]) %= MOD;
FWT_or(a, len, -1);
for (int i = 0; i < len; ++ i) printf("%lld ", a[i]); puts("");
}
void solve_and()
{
for (int i = 0; i < len; ++ i) a[i] = reca[i], b[i] = recb[i];
FWT_and(a, len, 1); FWT_and(b, len, 1);
for (int i = 0; i < len; ++ i) (a[i] *= b[i]) %= MOD;
FWT_and(a, len, -1);
for (int i = 0; i < len; ++ i) printf("%lld ", a[i]); puts("");
}
void solve_xor()
{
for (int i = 0; i < len; ++ i) a[i] = reca[i], b[i] = recb[i];
FWT_xor(a, len, 1); FWT_xor(b, len, 1);
for (int i = 0; i < len; ++ i) (a[i] *= b[i]) %= MOD;
FWT_xor(a, len, -1);
for (int i = 0; i < len; ++ i) printf("%lld ", a[i]); puts("");
}
int main()
{
n = in();
len = 1 << n;
for (int i = 0; i < len; ++ i) reca[i] = in();
for (int i = 0; i < len; ++ i) recb[i] = in();
solve_or();
solve_and();
solve_xor();
return 0;
}