【luogu P4717】【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)(数学)
【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)
题目链接:luogu P4717
题目大意
给你两个长度为 2^n 的数组 A,B,设数组 C:
C[i]=sum{j⊕k=i}A[j]B[k]
分别当 ⊕ 是 or,and,xor 三种运算符时求出数组 C。
思路
嗯这里只会讲简单的 FWT,而且不说 FMT。
(毕竟听说 FMT 被 FWT 完爆所以就没学awa)
如果要看复杂的 FWT 可以看这位大佬的博客。
其实感觉这个 FWT 的过程其实就是类比了 FFT 这一类的方式。
\(c_i=\sum\limits_{i=j⊕ k}a_jb_k\)。
那我们先看就题目的三种运算。
或
\(c_i=\sum\limits_{i=j|k}a_jb_k\)
考虑搞点性质,发现 \(j|i=i,k|i=i\rightarrow(k|j)|i=i\)。
然后我们考虑构造 \(fwt[a/b/c]_i\),使得 \(fwt[a]_i*fwt[b]_i=fwt[c]_i\)。
然后可以找到 \(fwt[a]_i=\sum\limits_{j|i=i}a_j\)。
\(fwt[a]_i*fwt[b]_i=(\sum\limits_{j|i=i}a_j)(\sum\limits_{k|i=i}b_k)\)
\(=\sum\limits_{j|i=i}\sum\limits_{k|i=i}a_jb_k=\sum\limits_{(j|k)|i=i}a_jb_k=fwt[c]_i\)
然后接着就是考虑 \(fwt[a]_i\) 这些怎么快速求。
考虑先看下标二进制的最高位,然后让 \(a0\) 表示它下标最高位为 \(0\) 的那部分序列,\(a1\) 表示为 \(1\) 的那部分。
然后根据或的性质就有:
\(fwt[a]_i=(fwt[a0]_i,fwt[a0]+fwt[a1])\)
(就是这两个部分拼在一起,加号是每个位置加起来)
然后至于从 \(fwt[a]_i\) 求会 \(a\) 可以根据上面的反过来得到:
\(a=(a0,a1-a0)\)
然后你会发现它的形式很像 FFT/NTT 里面的,然后你就类似着搞就可以了。
与
跟着或的道理,你会发现是:
\(fwt[a]_i=(fwt[a0]_i+fwt[a1],fwt[a1])\)
\(a=(a0-a1,a1)\)
异或
这个就不一样的,考虑再找性质:
如果我们让 \(x⊕y=count1(x\&y)\bmod 2\)(\(count(1)\) 是二进制中 \(1\) 的个数)
然后你会发现有 \((i⊕j)xor(i⊕k)=i⊕(j\ xor\ k)\)(你分类讨论一下会发现确实是这样的)
然后你就构造,可以得到:
\(fwt[a]_i=\sum\limits_{i⊕j=0}a_j-\sum\limits_{i⊕j=1}a_j\)
然后:
\(fwt[a]=(fwt[a0]+fwt[a1],fwt[a0]-fwt[a1])\)
\(a=(\dfrac{a0+a1}{2},\dfrac{a0-a1}{2})\)
构造怎么弄的?
相信你看了这三个的构造 \(fwt[a]_i\) 数组的结果,会有那么一丝丝的疑惑:为啥能想到这个构造方法。
那我们就简单讲讲如何构造 FWT 中的这个数组。
首先我们用未知数表示:
\(fwt[a]_i=\sum\limits_{j=0}^{n-1}s(i,j)a_j\)
然后列出要求的式子:
\(fwt[a]_i*fwt[b]_i=fwt[c]_i\)
\(\sum\limits_{j=0}^{n-1}s(i,j)a_j\sum\limits_{k=0}^{n-1}s(i,k)b_k=\sum\limits_{p=0}^{n-1}s(i,p)c_p\)
\(\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}s(i,j)s(i,k)a_jb_k=\sum\limits_{p=0}^{n-1}s(i,p)c_p\)
然后再有 \(a*b=c\):
\(c_p=\sum\limits_{j\oplus k=p}a_jb_k\)
\(\sum\limits_{p=0}^{n-1}s(i,p)c_p=\sum\limits_{p=0}^{n-1}s(i,p)\sum\limits_{j\oplus k=p}a_jb_k\)
\(\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}s(i,j)s(i,k)a_jb_k=\sum\limits_{p=0}^{n-1}s(i,p)\sum\limits_{j\oplus k=p}a_jb_k=\sum\limits_{p=0}^{n-1}\sum\limits_{j\oplus k=p}a_jb_ks(i,j\oplus k)=\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}a_jb_ks(i,j\oplus k)\)
\(\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}s(i,j)s(i,k)a_jb_k=\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}a_jb_ks(i,j\oplus k)\)
所以我们就需要让 \(s(i,j)s(i,k)=s(i,j\oplus k)\)。
接着就要用到 FWT 最特别的地方了:它是解决有关位运算的问题的。
也就是说它二进制每一位是互相独立的!
所以假设我们已经求出来对于一位的 \(s([0,1],[0,1])\),那我们就可以构造出所有的 \(s\):
设 \(a\) 二进制的每一位是 \(a_0,a_1,a_2,...\),那 \(s(i,j)=s(i_0,j_0)s(i_1,j_1)s(i_2,j_2)...\),就是每位的乘起来。
那么对于每一位:
\(s(i,j)s(i,k)=s(i,j\oplus k)\Leftrightarrow s(i_t,j_t)s(i_t,k_t)=s(i_t,j_t\oplus k_t)\)
那这个我们就可以每一位通过 \(0,1\) 的分类讨论求解,得到符合的 \(s\)。
而且在 or,and,xor 中它们的符合的 \(s\) 其实是有两种的,那我们选随便一个都可以用了。
比如 or 的是有这两种:
\(\begin{bmatrix}1&1\\1&0\end{bmatrix}\) 和 \(\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}\)
那构造就构造好啦!
\(fwt[a]\) 的求法感觉还是有点迷
那我们继续用上面的来:
\(fwt[a]_i=\sum\limits_{j=0}^{n-1}s(i,j)a_j\)
继续折半:\(\sum\limits_{j=0}^{n/2-1}s(i,j)a_j+\sum\limits_{j=n/2}^{n-1}s(i,j)a_j\)
然后也想前面那样 \(i'\) 为 \(i\) 去掉二进制首位的数。
\(\sum\limits_{j=0}^{n/2-1}s(i_0,j_0)s(i',j')a_j+\sum\limits_{j=n/2}^{n-1}s(i_0,j_0)s(i',j')a_j\)
\(\sum\limits_{j=0}^{n/2-1}s(i_0,0)s(i',j')a_j+\sum\limits_{j=n/2}^{n-1}s(i_0,1)s(i',j')a_j\)
那 \(c(i',j')\) 就是去掉首位的,规模自然减半。
当 \(i_0=0\) 即 \(0\leqslant i<n/2\),\(fwt[a]_i=s(0,0)fwt(a_0)_i+s(0,1)fwt(a_1)_i\)
当 \(i_0=1\) 即 \(n/2\leqslant <n\),\(fwt[a]_i=s(1,0)fwt(a_0)_i+s(1,1)fwt(a_1)_i\)
然后如果是 \(ifwt\)(就是从 \(fwt[a]\) 到 \(a\))就是把 \(s\) 这个矩阵求逆。
扩展一下,如果不是位运算还有可能用的上吗
其实是可以的,因为位运算你可以看做是 \(n\) 维 \(01\) 向量做运算:
or 是取 \(\max\),and 是取 \(\min\),xor 是每一维相加的结果 \(\bmod\ 2\)。
那我们可以扩展到 \([0,k)\),那我们要的 \(s\) 就是一个 \(k*k\) 的矩阵,然后暴力算是 \(k^{n+1}n\)。
然后这些矩阵也可以快速算,\(\max\min\) 是高位前缀和压掉一个 \(k\),\(\mod k\) 的话列 \(s\) 可以用范德蒙德矩阵。
然后可以用 FTT 把一个 \(k\) 变成 \(\log k\)。
r然而又因为单位根模的意义下可能不存在,所以你要通过再来一个 \(k\) 的复杂度以及一通神仙操作(别想了我不可能会的去看那位大佬的博客吧)
代码
#include<cstdio>
#include<cstring>
#define mo 998244353
#define cpy(f, g, n) memcpy(f, g, sizeof(int) * (n))
#define clr(f, n) memset(f, 0, sizeof(int) * (n))
using namespace std;
const int N = (1 << 17);
int n, f[N], g[N], inv2, tmp[N];
int jia(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int jian(int x, int y) {return x - y < 0 ? x - y + mo : x - y;}
int cheng(int x, int y) {return 1ll * x * y % mo;}
void px(int *f, int *g, int n) {
for (int i = 0; i < n; i++)
f[i] = cheng(f[i], g[i]);
}
void FWT_or(int *f, int n, int op) {
for (int mid = 1; mid < n; mid <<= 1)
for (int j = 0; j < n; j += (mid << 1))
for (int k = 0; k < mid; k++) {
int x = f[j | k], y = f[j | mid | k];
f[j | k] = x; f[j | mid | k] = jia(cheng((op == 1) ? 1 : mo - 1, x), y);
}
}
void FWT_and(int *f, int n, int op) {
for (int mid = 1; mid < n; mid <<= 1)
for (int j = 0; j < n; j += (mid << 1))
for (int k = 0; k < mid; k++) {
int x = f[j | k], y = f[j | mid | k];
f[j | k] = jia(x, cheng((op == 1) ? 1 : mo - 1, y)); f[j | mid | k] = y;
}
}
void FWT_xor(int *f, int n, int op) {
for (int mid = 1; mid < n; mid <<= 1)
for (int j = 0; j < n; j += (mid << 1))
for (int k = 0; k < mid; k++) {
int x = f[j | k], y = f[j | mid | k];
f[j | k] = jia(x, y); f[j | mid | k] = jian(x, y);
if (op == -1) f[j | k] = cheng(f[j | k], inv2), f[j | mid | k] = cheng(f[j | mid | k], inv2);
}
}
void cheng_or(int *f, int *g, int n) {
static int Tmp[N];
cpy(Tmp, g, n);
FWT_or(f, n, 1); FWT_or(Tmp, n, 1);
px(f, Tmp, n); FWT_or(f, n, -1);
clr(Tmp, n);
}
void cheng_and(int *f, int *g, int n) {
static int tmp[N];
cpy(tmp, g, n);
FWT_and(f, n, 1); FWT_and(tmp, n, 1);
px(f, tmp, n); FWT_and(f, n, -1);
clr(tmp, n);
}
void cheng_xor(int *f, int *g, int n) {
static int tmp[N];
cpy(tmp, g, n);
FWT_xor(f, n, 1); FWT_xor(tmp, n, 1);
px(f, tmp, n); FWT_xor(f, n, -1);
clr(tmp, n);
}
int main() {
scanf("%d", &n); n = (1 << n);
for (int i = 0; i < n; i++) scanf("%d", &f[i]);
for (int i = 0; i < n; i++) scanf("%d", &g[i]);
inv2 = (mo + 1) / 2;
cpy(tmp, f, n); cheng_or(tmp, g, n); for (int i = 0; i < n; i++) printf("%d ", tmp[i]); puts("");
cpy(tmp, f, n); cheng_and(tmp, g, n); for (int i = 0; i < n; i++) printf("%d ", tmp[i]); puts("");
cpy(tmp, f, n); cheng_xor(tmp, g, n); for (int i = 0; i < n; i++) printf("%d ", tmp[i]); puts("");
return 0;
}