浅谈FWT
这种东西直接背板子就好了
感觉如果懂FFT的话FWT还是很好理解的
也是转换为点值表示法(FWT)然后对应位乘起来就好了
然后再IFWT就可以了
像我就直接背板子
A 0 是 二 进 制 开 头 为 0 的 , A 1 是 二 进 制 开 头 为 1 的 A_0是二进制开头为0的,A_1是二进制开头为1的 A0是二进制开头为0的,A1是二进制开头为1的
or
如果卷积长这样 C k = ∑ i ∣ j = k A i ∗ B j C_k=∑_{i|j=k} A_i ∗ B_j Ck=∑i∣j=kAi∗Bj
F W T ( A ) = ( F W T ( A 0 ) , F W T ( A 0 + A 1 ) ) FWT(A) = (FWT(A_0) , FWT(A_0 + A_1)) FWT(A)=(FWT(A0),FWT(A0+A1))
I F W T ( A ) = ( F W T ( A 0 ) , F W T ( A 0 − A 1 ) ) IFWT(A) = (FWT(A_0) , FWT(A_0 - A_1)) IFWT(A)=(FWT(A0),FWT(A0−A1))
and
如果卷积长这样 C k = ∑ i & j = k A i ∗ B j C_k=∑_{i\&j=k} A_i ∗ B_j Ck=∑i&j=kAi∗Bj
F W T ( A ) = ( F W T ( A 0 + A 1 ) , F W T ( A 1 ) ) FWT(A) = (FWT(A_0 + A_1) , FWT(A_1)) FWT(A)=(FWT(A0+A1),FWT(A1))
I F W T ( A ) = ( F W T ( A 0 − A 1 ) , F W T ( A 1 ) ) IFWT(A) = (FWT(A_0 - A_1) , FWT(A_1)) IFWT(A)=(FWT(A0−A1),FWT(A1))
xor
如果卷积长这样 C k = ∑ i ⊕ j = k A i ∗ B j C_k=∑_{i ⊕j=k} A_i ∗ B_j Ck=∑i⊕j=kAi∗Bj
F W T ( A ) = ( F W T ( A 0 + A 1 ) , F W T ( A 0 − A 1 ) ) FWT(A) = (FWT(A_0 + A_1) , FWT(A_0 - A_1)) FWT(A)=(FWT(A0+A1),FWT(A0−A1))
I F W T ( A ) = ( F W T ( A 0 + A 1 ) / 2 , F W T ( A 0 − A 1 ) / 2 ) IFWT(A) = (FWT(A_0 + A_1) / 2 , FWT(A_0 - A_1) / 2) IFWT(A)=(FWT(A0+A1)/2,FWT(A0−A1)/2)
然后就和FFT一样用就行了
code:
#include<bits/stdc++.h>
#define mod 998244353
#define ll long long
using namespace std;
const int N = (1 << 17) + 5;
void fwt_or(ll *a, int len, int opt){
for(int i = 2; i <= len; i <<= 1)
for(int p = i >> 1, j = 0; j + i <= len; j += i)
for(int k = j; k < j + p; k ++)
a[p + k] =(a[p + k] + opt * a[k] + mod) % mod;
}
void fwt_and(ll *a, int len, int opt){
for(int i = 2; i <= len; i <<= 1)
for(int p = i >> 1, j = 0; j + i <= len; j += i)
for(int k = j; k < j + p; k ++)
a[k] = (a[k] + opt * a[k + p] + mod) % mod;
}
void fwt_xor(ll *a, int len, int opt){
for(int i = 2; i <= len; i <<= 1)
for(int p = i >> 1, j = 0; j + i <= len; j += i)
for(int k = j; k < j + p; k ++){
int X = a[k], Y = a[k + p];
a[k] = (X + Y) % mod, a[k + p] = (X - Y + mod) % mod;
if(opt == -1) a[k] = a[k] * 499122177 % mod, a[k + p] = a[k + p] * 499122177 % mod;
}
}
int n;
ll a[N], b[N], f[N], g[N];
int main(){
scanf("%d", &n);
int len = (1 << n);
for(int i = 0; i < len; i ++) scanf("%lld", &a[i]);
for(int i = 0; i < len; i ++) scanf("%lld", &b[i]);
for(int i = 0; i < len; i ++) f[i] = a[i], g[i] = b[i];
fwt_or(f, len, 1), fwt_or(g, len, 1);
for(int i = 0; i < len; i ++) f[i] = f[i] * g[i] % mod;
fwt_or(f, len, -1);
for(int i = 0; i < len; i ++) printf("%lld ", f[i]); printf("\n");
for(int i = 0; i < len; i ++) f[i] = a[i], g[i] = b[i];
fwt_and(f, len, 1), fwt_and(g, len, 1);
for(int i = 0; i < len; i ++) f[i] = f[i] * g[i] % mod;
fwt_and(f, len, -1);
for(int i = 0; i < len; i ++) printf("%lld ", f[i]); printf("\n");
for(int i = 0; i < len; i ++) f[i] = a[i], g[i] = b[i];
fwt_xor(f, len, 1), fwt_xor(g, len, 1);
for(int i = 0; i < len; i ++) f[i] = f[i] * g[i] % mod;
fwt_xor(f, len, -1);
for(int i = 0; i < len; i ++) printf("%lld ", f[i]); printf("\n");
return 0;
}
咕咕咕咕