FWT小记
前言
这是博主最后一年寒假时候学的,仅为了自己复习写的,所以不够详细。等有时间了大概会来补充完善一下。
这篇博客或许不错
解决问题
\(c_i=\sum\limits_{j\circ k=i}a_j\times b_k\)
核心思想
对于 \(a,b\) 找到一个可逆变换,使得可以将变换后的 \(a,b\) 直接点乘得到变换后的 \(c\) ,然后逆运算回来。
或卷积
\(f_i=\sum\limits_{j \ |\ i=i} a_j\),\(g_i=\sum\limits_{j \ |\ i=i} b_j\),\(h_i=f_i \times g_i\)
\(h_i=\sum\limits_{j\ |\ i = i} \ \sum\limits_{k\ |\ i=i} a_j\times b_k=\sum\limits_{(j\ |\ k)\ | \ i=i}a_j\times b_k\)
所以 \(h\) 就是 \(c\) 的变换后的数组。此时对 \(h\) 做一个高维差分即可得到 \(c\) 。
与卷积
\(f_i=\sum\limits_{j \ \&\ i=i} a_j\),\(g_i=\sum\limits_{j \ \&\ i=i} b_j\),\(h_i=f_i \times g_i\)
对 \(h\) 做一个高维后缀差分。
异或卷积
定义 \(F(x)=popcount(x) \mod 2\)
\(f_i=\sum\limits_{F(i\ \otimes j)=0}a_i-\sum\limits_{F(i\ \otimes j)=1}a_i\),\(g_i=\sum\limits_{F(i\ \otimes j)=0}a_i-\sum\limits_{F(i\ \otimes j)=1}b_i\)
可以暴力分类讨论得出 \(h_i=f_i \times g_i\) 。
实现方式
因为这类变换 位之间都是独立的,我们考虑类似于 \(DP\) 那样从低到高一位一位地去实现。
void fwtor(ll *f,int op){
for(ri len = 2,h = 1;len <= lim;len <<= 1,h <<= 1)
for(ri i = 0;i < lim;i += len)
for(ri j = i;j < i + h;++j)
f[j + h] += f[j] * op,f[j + h] %= mod;
}
void fwtand(ll *f,int op){
for(ri len = 2,h = 1;len <= lim;len <<= 1,h <<= 1)
for(ri i = 0;i < lim;i += len)
for(ri j = i;j < i + h;++j)
f[j] += f[j + h] * op,f[j] %= mod;
}
void fwtxor(ll *f,int op){
for(ri len = 2,h = 1;len <= lim;len <<= 1,h <<= 1)
for(ri i = 0;i < lim;i += len)
for(ri j = i;j < i + h;++j){
f[j] = (f[j] + f[j+h]) % mod;
f[j+h] = ((f[j] - f[j+h] - f[j+h]) % mod + mod) % mod;
f[j] = f[j] * op % mod;
f[j+h] = f[j+h] * op % mod;
}
}
子集卷积
\(c_i=\sum\limits_{j\ | \ k=i\ ,\ j \ \&\ k = i}a_j\times b_k\)
将条件转换,\(\sum\limits_{j\ | \ k=i\ ,\ F(j) + F(k) = F(j|k)}\) , \(F\) 定义同上
按 \(F(i)\) 将 \(a\) 分类 ,\(a'_{F(i),i} = a_i\) ,把 \(a',b'\) 求出 \(\text{fwt}\) ,然后手动卷那个 \(F(i)\) ,最后把 \(c'\) 的 \(\text{fwt}\) 搞回 \(c'\) ,\(c_i=c'_{F(i),i}\)
//from 2022.2.3 11:40
#include<bits/stdc++.h>
#define ri register int
#define ll long long
using namespace std;
const int maxn = (1<<20) + 5,mod = 1e9 + 9;
inline int rd(){
int res = 0,f = 0; char ch = getchar();
for(;!isdigit(ch);ch = getchar()) if(ch == '-') f = 1;
for(;isdigit(ch);ch = getchar()) res = (res<<3) + (res<<1) + ch - 48;
return f ? -res : res;
}
int n;
inline void fwt(ll *f,int op){//or 卷积
for(ri mid = 1;mid < (1<<n);mid <<= 1)
for(ri l = 0,len = (mid<<1);l < (1<<n);l += len)
for(ri i = 0;i < mid;++i)
if(op == 1) f[l + mid + i] = (f[l + mid + i] + f[l + i]) % mod;
else f[l + mid + i] = (f[l + mid + i] - f[l + i] + mod) % mod;
}
int pop[maxn];
ll f[21][maxn],g[21][maxn],h[21][maxn];
int main(){
n = rd();
for(ri i = 1;i < (1<<n);++i) pop[i] = pop[i>>1] + (i&1);
for(ri i = 0;i < (1<<n);++i) f[pop[i]][i] = rd();
for(ri i = 0;i < (1<<n);++i) g[pop[i]][i] = rd();
for(ri i = 0;i <= n;++i) fwt(f[i],1),fwt(g[i],1);
for(ri i = 0;i <= n;++i)
for(ri j = 0;j <= i;++j)
for(ri k = 0;k < (1<<n);++k)
h[i][k] = (h[i][k] + f[j][k] * g[i-j][k] % mod) % mod;
for(ri i = 0;i <= n;++i) fwt(h[i],-1);
for(ri i = 0;i < (1<<n);++i) printf("%lld ",h[pop[i]][i]);
puts("");
return 0;
}