suxxsfe

一言(ヒトコト)

FWT 笔记

用来解决在下标中进行位运算的卷积
具体形式就是求

\[c_i=\sum_{i=j\oplus k} a_j\cdot b_j \]

思路大概就是把序列 \(a\) 变换为 \(fwt(a)\)\(b,c\) 同理,使得 \(fwt(c)=fwt(a) fwt(b)\),这样得到了 \(fwt(c)\) 再变换回来

或运算

\[c_i=\sum_{i=j|k} a_j\cdot b_j \]

构造 \(fwt(a)_i=\sum_{i|j=i} a_j\)
也就是,取所有下标是 \(i\) 的二进制位中子集的 \(a_j\) 的和
由于有 \(i|j=i,i|k=i\Rightarrow i|(j|k)=i\),所以:

\[fwt(a)_i\cdot fwt(b)_i=\left(\sum_{i=i|j}a_j\right)\cdot \left(\sum_{i=i|k}b_k \right)=\sum_{i=i|j}\sum_{i=i|k} a_j b_k=\sum_{i=i|(j|k)} a_j b_k=fwt(c)_i \]

于是考虑如何求 \(fwt(a)\),考虑分治,用 \(a_0,a_1\) 分别表示 \(a\) 序列中下标第一个二进制位为 \(0/1\) 的情况(各 \(2^{n-1}\) 个数),则有:

\[fwt(a)=\operatorname{merge}(fwt(a_0),fwt(a_0+a_1)) \]

\(\operatorname{merge}\) 为链接,加号为每一位分别相加
就是说由于是求下标是它子集的元素的和,那么 \(a_1\) 是可以将第一个二进制位改为 \(0\),得到它的一个子集,也就是包含了 \(a_0\)

再考虑如何由 \(fwt(a)\) 求出 \(a\),其实直接反过来就好:

\[a=\operatorname{merge}(a_0,a_1-a_0) \]

与运算

\[c_i=\sum_{i=j\& k} a_j\cdot b_j \]

其实和或运算类似,构造 \(fwt(a)_i=\sum_{i\&j=i} a_j\)
然后分治的时候,可以把 \(a_0\) 的第一个二进制位改为 \(1\),包含上 \(a_1\),于是:

\[fwt(a)=\operatorname{merge}(fwt(a_0)+fwt(a_1),fwt(a_1)) \]

\[a=\operatorname{merge}(a_0-a_1,a_1) \]

异或

\[c_i=\sum_{i=j\operatorname{xor}k} a_j\cdot b_j \]

稍微复杂一些,不能用子集的关系表示了
\(f(i,j)=\operatorname{popcount}(i\&j) \bmod 2\)

有:\(f(i,j)\operatorname{xor}f(i,k)=f(i,j\operatorname{xor}k)\)
证明大概就是,因为是先与运算再统计二进制中 \(1\) 的个数,所以只用考虑 \(i\)\(1\) 的那几位,如果 \(j,k\) 在这些位上也是 \(1\) 的个数的奇偶性相同,那么他们中有一部分是重叠的会被异或掉,剩下的显然奇偶性仍然相同,那么总共偶数个,结果为 \(0\)
如果奇偶性不同,那么重叠的部分被异或掉以后,剩下的奇偶性仍然不同,总共奇数个,结果为 \(1\)

那么此时就可以构造:

\[fwt(a)=\sum_{f(i,j)=0} a_j-\sum_{f(i,j)=1} a_j \]

那么相乘就是

\[\begin{aligned}fwt(a)_i\cdot fwt(b)_i &=\left(\sum_{f(i,j)=0} a_j-\sum_{f(i,j)=1} a_j\right)\left(\sum_{f(i,k)=0} b_k-\sum_{f(i,k)=1} b_k\right)\\ &=\sum_{f(i,j)=0} a_j \sum_{f(i,k)=0} b_k-\sum_{f(i,j)=1} a_j\sum_{f(i,k)=0} b_k-\sum_{f(i,j)=0} a_j\sum_{f(i,k)=1} b_k+\sum_{f(i,j)=1} a_j\sum_{f(i,k)=1} b_k\\ &=\left(\sum_{f(i,j)=0}\sum_{f(i,k)=0} a_jb_k+\sum_{f(i,j)=1}\sum_{f(i,k)=1} a_jb_k\right)-\left(\sum_{f(i,j)=1}\sum_{f(i,k)=0} a_jb_k+\sum_{f(i,j)=0}\sum_{f(i,k)=1} a_jb_k\right)\\ &=\left(\sum_{f(i,j\operatorname{xor}k)=0\operatorname{xor}0} a_jb_k+\sum_{f(i,j\operatorname{xor}k)=1\operatorname{xor}1} a_jb_k\right)-\left(\sum_{f(i,j\operatorname{xor}k)=1\operatorname{xor}0} a_jb_k+\sum_{f(i,j\operatorname{xor}k)=0\operatorname{xor}1} a_jb_k\right)\\ &=\sum_{f(i,j\operatorname{xor}k)=0} a_jb_k-\sum_{f(i,j\operatorname{xor}k)=1} a_jb_k\\ &=fwt(c)_i\\ \end{aligned} \]

然后考虑分治的时候如何计算,有:

\[fwt(a)=\operatorname{merge}(a_0+a_1,a_0-a_1) \]

\[a=\operatorname{merge}(\frac{a_0+a_1}{2},\frac{a_0-a_1}{2}) \]

原理上,就是对于前 \(2^{n-1}\) 个数,最高位是 \(0\),由于 \(0\&0=0\&1=0\),对 \(f\) 的结果没有影响,直接简单相加
\(2^{n-1}\) 个数,最高位是 \(1\),由于 \(1\&0=0,1\&1=1\),使得 \(f\) 结果改变,应为 \(-a_1\)


模板题:https://www.luogu.com.cn/problem/P4717

#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#define reg register
#define LL_INF (long long)(0x3f3f3f3f3f3f3f3f)
#define INT_INF (int)(0x3f3f3f3f)
inline int read(){
	register int x=0;register int y=1;
	register char c=std::getchar();
	while(c<'0'||c>'9'){if(c=='-') y=0;c=getchar();}
	while(c>='0'&&c<='9'){x=x*10+(c^48);c=getchar();}
	return y?x:-x;
}
#define mod 998244353
#define inv 499122177
#define N 131078
int n;
long long A[N],B[N],C[N];
long long a[N],b[N],c[N];
inline void OR(long long *f,long long *a,int x){
	for(reg int i=0;i<n;i++) f[i]=a[i];
	for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
		for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++)
			f[i+j+k]=(f[i+j+k]+mod+(x?f[i+j]:-f[i+j]))%mod;
}
inline void AND(long long *f,long long *a,int x){
	for(reg int i=0;i<n;i++) f[i]=a[i];
	for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
		for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++)
			f[i+j]=(f[i+j]+mod+(x?f[i+j+k]:-f[i+j+k]))%mod;
}
inline void XOR(long long *f,long long *a,int x){
	for(reg int i=0;i<n;i++) f[i]=a[i];
	for(reg int o=2,k=1;o<=n;o<<=1,k<<=1)
		for(reg int i=0;i<n;i+=o)for(reg int j=0;j<k;j++){
			f[i+j]=(f[i+j]+f[i+j+k])%mod;
			f[i+j+k]=(f[i+j]-f[i+j+k]-f[i+j+k]+mod+mod)%mod;
			if(!x) f[i+j]=f[i+j]*inv%mod,f[i+j+k]=f[i+j+k]*inv%mod;
		}
}
inline void calc(long long *a,long long *b,long long *c){
	for(reg int i=0;i<n;i++) c[i]=a[i]*b[i]%mod;
}
int main(){
	n=(1<<read());
	for(reg int i=0;i<n;i++) A[i]=read();
	for(reg int i=0;i<n;i++) B[i]=read();
	OR(a,A,1);OR(b,B,1);calc(a,b,c);OR(C,c,0);
	for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
	AND(a,A,1);AND(b,B,1);calc(a,b,c);AND(C,c,0);
	for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
	XOR(a,A,1);XOR(b,B,1);calc(a,b,c);XOR(C,c,0);
	for(reg int i=0;i<n;i++) printf("%d ",C[i]);puts("");
	return 0;
}
posted @ 2021-03-11 21:17  suxxsfe  阅读(72)  评论(0编辑  收藏  举报