FWT 笔记
用来解决在下标中进行位运算的卷积
具体形式就是求
思路大概就是把序列 \(a\) 变换为 \(fwt(a)\),\(b,c\) 同理,使得 \(fwt(c)=fwt(a) fwt(b)\),这样得到了 \(fwt(c)\) 再变换回来
或运算
构造 \(fwt(a)_i=\sum_{i|j=i} a_j\)
也就是,取所有下标是 \(i\) 的二进制位中子集的 \(a_j\) 的和
由于有 \(i|j=i,i|k=i\Rightarrow i|(j|k)=i\),所以:
于是考虑如何求 \(fwt(a)\),考虑分治,用 \(a_0,a_1\) 分别表示 \(a\) 序列中下标第一个二进制位为 \(0/1\) 的情况(各 \(2^{n-1}\) 个数),则有:
\(\operatorname{merge}\) 为链接,加号为每一位分别相加
就是说由于是求下标是它子集的元素的和,那么 \(a_1\) 是可以将第一个二进制位改为 \(0\),得到它的一个子集,也就是包含了 \(a_0\)
再考虑如何由 \(fwt(a)\) 求出 \(a\),其实直接反过来就好:
与运算
其实和或运算类似,构造 \(fwt(a)_i=\sum_{i\&j=i} a_j\)
然后分治的时候,可以把 \(a_0\) 的第一个二进制位改为 \(1\),包含上 \(a_1\),于是:
异或
稍微复杂一些,不能用子集的关系表示了
设 \(f(i,j)=\operatorname{popcount}(i\&j) \bmod 2\)
有:\(f(i,j)\operatorname{xor}f(i,k)=f(i,j\operatorname{xor}k)\)
证明大概就是,因为是先与运算再统计二进制中 \(1\) 的个数,所以只用考虑 \(i\) 为 \(1\) 的那几位,如果 \(j,k\) 在这些位上也是 \(1\) 的个数的奇偶性相同,那么他们中有一部分是重叠的会被异或掉,剩下的显然奇偶性仍然相同,那么总共偶数个,结果为 \(0\)
如果奇偶性不同,那么重叠的部分被异或掉以后,剩下的奇偶性仍然不同,总共奇数个,结果为 \(1\)
那么此时就可以构造:
那么相乘就是
然后考虑分治的时候如何计算,有:
原理上,就是对于前 \(2^{n-1}\) 个数,最高位是 \(0\),由于 \(0\&0=0\&1=0\),对 \(f\) 的结果没有影响,直接简单相加
后 \(2^{n-1}\) 个数,最高位是 \(1\),由于 \(1\&0=0,1\&1=1\),使得 \(f\) 结果改变,应为 \(-a_1\)
模板题:https://www.luogu.com.cn/problem/P4717
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#define reg register
#define LL_INF (long long)(0x3f3f3f3f3f3f3f3f)
#define INT_INF (int)(0x3f3f3f3f)
inline int read(){
register int x=0;register int y=1;
register char c=std::getchar();
while(c<'0'||c>'9'){if(c=='-') y=0;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+(c^48);c=getchar();}
return y?x:-x;
}
#define mod 998244353
#define inv 499122177
#define N 131078
int n;
long long A[N],B[N],C[N];
long long a[N],b[N],c[N];
inline void OR(long long *f,long long *a,int x){
for(reg int i=0;i<n;i++) f[i]=a[i];
for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++)
f[i+j+k]=(f[i+j+k]+mod+(x?f[i+j]:-f[i+j]))%mod;
}
inline void AND(long long *f,long long *a,int x){
for(reg int i=0;i<n;i++) f[i]=a[i];
for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++)
f[i+j]=(f[i+j]+mod+(x?f[i+j+k]:-f[i+j+k]))%mod;
}
inline void XOR(long long *f,long long *a,int x){
for(reg int i=0;i<n;i++) f[i]=a[i];
for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++){
f[i+j]=(f[i+j]+f[i+j+k])%mod;
f[i+j+k]=(f[i+j]-f[i+j+k]-f[i+j+k]+mod+mod)%mod;
if(!x) f[i+j]=f[i+j]*inv%mod,f[i+j+k]=f[i+j+k]*inv%mod;
}
}
inline void calc(long long *a,long long *b,long long *c){
for(reg int i=0;i<n;i++) c[i]=a[i]*b[i]%mod;
}
int main(){
n=(1<<read());
for(reg int i=0;i<n;i++) A[i]=read();
for(reg int i=0;i<n;i++) B[i]=read();
OR(a,A,1);OR(b,B,1);calc(a,b,c);OR(C,c,0);
for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
AND(a,A,1);AND(b,B,1);calc(a,b,c);AND(C,c,0);
for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
XOR(a,A,1);XOR(b,B,1);calc(a,b,c);XOR(C,c,0);
for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
return 0;
}