FWT 学习笔记

更新于 2023.2.20,之前写的很粗糙,补一下。


讲一下原理,类似 DFT 的,我们也希望构造一个 FWT 数组使得卷积可以写作变幻、点积、变幻。

对于 or 卷积,构造 \(FWT(F_i)=\sum\limits_{j\subseteq i} F_j\),对于 and 卷积构造 \(FWT(F_i)=\sum\limits_{i\subseteq j}F_j\)

可以递归求,以 or 卷积为例:

\[FWT(F)=merge(FWT(F_0),FWT(F_0)+FWT(F_1))\\ IFWT(F)=merge(IFWT(F_0),IFWT(F_1)-IFWT(F_0)) \]

可以发现我们循环实现的就是这个东西,我们只需要构造 \(FWT\) 的方法,然后对两部做变幻。

异或卷积比较特殊,有 \(FWT(F_i)=\sum\limits_{j}(-1)^{|i\cap j|} F_j\),变幻是容易构造的。

类似 FFT 地,FWT 也有这样的代码实现:

inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
	for(int i=1;i<len;i<<=1)
		for(int j=0;j<len;j+=i+i)
			for(int k=j;k<j+i;k++){
				int u=f[k],v=f[k+i];
				f[k]=add(u+v);f[k+i]=sub(u-v);
				if(!flg){
					f[k]=1ll*f[k]*iv2%mod;
					f[k+i]=1ll*f[k+i]*iv2%mod;
				}
			}
}

相比于 FMT,FWT 有一半的常数并且可以执行异或操作。

uoj310. 黎明前的巧克力

题意:给定集合 \(S\),计数 \(S1,S2,S3\) 互不相交,并为 \(S\),且 \(\oplus S3=0\) 的方案数。

题解:显然有 \(dp\)\(f_{i,j}=f_{i-1,j},f_{i,j}=2f_{i-1,j\oplus{a_i}}\),令 \(F_{i,0}=1,F_{i,a_i}=2\),则答案为 \(FWT(\prod F_i)\)

考察 FWT 的性质,首先 FWT 是线性变换,可以把 \(0\)\(a_i\) 拆开算。

类似 FFT 地,FWT 相当于给每一维做 FFT,于是 \(FWT(F)(n)=\sum\limits_{k=0}^{2^l-1} F_k\prod\limits_{j=0}^l (-1)^{n_jk_j}=\sum\limits_{k=0}^{2^l-1} F_k(-1)^{n\&k}\)

于是显然 \(F_0=1\) 的 FWT 就是每一位都是 \(1\)\(F_{a_i}=2\) 的 FWT 就是一堆 \(2\)\(-2\)

于是我们只用知道最终有多少个 \(-1\)\(3\) 就好了,显然我们知道和就能解出来,而和就是整体的 FWT

#include<bits/stdc++.h>
using namespace std;
#define inf 1e9
const int maxn=2e5+10;
const int mod=998244353;
const int iv2=(mod+1)/2;
inline int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
	return x*f;
}
inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
	for(int i=1;i<len;i<<=1)
		for(int j=0;j<len;j+=i+i)
			for(int k=j;k<j+i;k++){
				int u=f[k],v=f[k+i];
				f[k]=u+v;f[k+i]=u-v;
				if(!flg){
					f[k]=1ll*f[k]*iv2%mod;
					f[k+i]=1ll*f[k+i]*iv2%mod;
					f[k]=sub(f[k]);
					f[k+i]=sub(f[k+i]);
				}
			}
}
const int len=1<<20;
int n,m,a[len+5];
inline int ksm(int x,int y){
	int res=1;
	while(y){
		if(y&1)res=1ll*res*x%mod;
		x=1ll*x*x%mod;y>>=1;
	}return res;
}
int main(){
	n=read();
	for(int i=1,x;i<=n;i++)
		x=read(),a[x]++;
	fwt(a,len,1);
	for(int i=0;i<len;i++)
		a[i]=1ll*ksm(3,(a[i]+n)/2)*ksm(mod-1,(n-a[i])/2)%mod;
	fwt(a,len,0);
	printf("%d\n",sub(a[0]-1));
	return 0;
}

CF662C Binary Table

首先很容易想到 \(O(2^n)\) 枚举翻转行集合,然后 \(O(m)\) 判断 \(0/1\) 最小值。

\(cnt_i\) 表示数值个数,\(a_i\) 表示状态 \(i\)\(0/1\) 最小值,则翻转集合为 \(S\) 的答案为 \(F_S=\sum cnt_i a_{i\oplus S}\),显然的 FWT 形式。

#include<bits/stdc++.h>
using namespace std;
#define inf 1e16
const int maxn=2e5+10;
const int mod=1e9+7;
inline int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
	return x*f;
}
#define ll long long
const int N=1<<20;
int n,m,lim;
ll a[N+5],b[N+5];
inline void fwt(ll *f,int len,int flg){
	for(int i=2;i<=len;i<<=1)
		for(int j=0,p=i/2;j<len;j+=i)
			for(int k=j;k<j+p;k++){
				ll u=f[k],v=f[k+p];
				f[k]=u+v;f[k+p]=u-v;
				if(!flg)f[k]/=2,f[k+p]/=2;
			}
}
int main(){
	n=read(),m=read(),lim=1<<n;
	for(int i=0;i<n;i++)
		for(int j=1,x;j<=m;j++)
			scanf("%1d",&x),a[j]=a[j]+(x<<i);
	for(int i=1;i<=m;i++)b[a[i]]++;
	for(int i=1;i<lim;i++)a[i]=a[i&(i-1)]+1;
	for(int i=1;i<lim;i++)a[i]=min(a[i],n-a[i]);
	fwt(a,lim,1),fwt(b,lim,1);
	for(int i=0;i<lim;i++)a[i]*=b[i];
	fwt(a,lim,0);ll Mn=inf;
	for(int i=0;i<lim;i++)Mn=min(Mn,a[i]);
	printf("%lld\n",Mn);
	return 0;
}

CF850E Random Elections

太着急看题解了,下次可以固定一个时间强制思考。

三人等价,不妨假设 \(A\) 赢了 \(B,C\),将 \(AB,AC,BC\) 也用 \(0/1\) 表示,则有 \(AB,AC\)\(00\)\(11\) 时有一种方案,否则两种。

则我们将 \(F\) 跟自己卷一遍,然后每个位置乘上二的 \(0\) 个数次方加入答案即可。

#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
	return x*f;
}
const int maxn=2e5+10;
const int mod=1e9+7;
const int iv2=(mod+1)/2;
inline int ksm(int x,int y){
	int res=1;
	while(y){
		if(y&1)res=1ll*res*x%mod;
		x=1ll*x*x%mod;y>>=1;
	}return res;
}
inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
	for(int i=2;i<=len;i<<=1)
		for(int j=0,p=i/2;j<len;j+=i)
			for(int k=j;k<j+p;k++){
				int u=f[k],v=f[k+p];
				f[k]=add(u+v);
				f[k+p]=sub(u-v);
				if(!flg){
					f[k]=1ll*f[k]*iv2%mod;
					f[k+p]=1ll*f[k+p]*iv2%mod;
				}
			}
}
const int N=1<<20;
int f[N+5],lim,ans,n;
int main(){
	n=read();lim=1<<n;
	for(int i=0;i<lim;i++)
		scanf("%1d",&f[i]);
	fwt(f,lim,1);
	for(int i=0;i<lim;i++)
		f[i]=1ll*f[i]*f[i]%mod;
	fwt(f,lim,0);
	for(int i=0;i<lim;i++)
		ans=(ans+1ll*f[i]*(1<<n-__builtin_popcount(i)))%mod;
	printf("%d\n",3ll*ans%mod);
	return 0;
}

WC2018 州区划分

首先可以得到一个简单的 \(3^n\) dp,轻松拿到 \(n\leqslant 15\) 的部分分。

发现瓶颈在于解决形如 \(f_i=\sum\limits_{j|k=i,j\&k=0} f_jg_k\),直接上子集卷积即可。

#include<bits/stdc++.h>
using namespace std;
#define inf 1e9
const int maxn=2e5+10;
const int mod=998244353;
const int iv2=(mod+1)/2;
inline int read(){
	int x=0,f=1;char c=getchar();
	while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
	return x*f;
}
inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
	for(int i=2;i<=len;i<<=1)
		for(int j=0,l=i/2;j<len;j+=i)
			for(int k=j;k<j+l;k++){
				int u=f[k],v=f[k+l];
				if(flg)f[k+l]=add(u+v);
				else f[k+l]=sub(v-u);
			}
}
const int N=21;
int n,m,p,lk[1<<N],w[1<<N],ppc[1<<N],con[1<<N],ok[1<<N],f[1<<N],iw[1<<N],dp[N+1][1<<N],a[N+1][1<<N];
inline int ksm(int x,int y){
	int res=1;
	while(y){
		if(y&1)res=1ll*res*x%mod;
		x=1ll*x*x%mod;y>>=1;
	}return res;
}
inline int calc(int j,int i){
	int v=1ll*w[j]*iw[i]%mod;
	return ksm(v,p);
}
vector<int>G[N];
#define pb push_back
int vis[N],lg[1<<N];
inline void dfs(int sta,int x){
	if(vis[x])return;vis[x]=1;
	for(auto t:G[x])
		if(sta>>t&1)dfs(sta,t);
}
int main(){
	freopen("P4221.in","r",stdin);
	freopen("P4221.out","w",stdout);
	n=read(),m=read(),p=read();
	for(int i=1,u,v;i<=m;i++){
		u=read()-1,v=read()-1;
		lk[1<<u]|=(1<<v);lk[1<<v]|=(1<<u);
		G[u].pb(v),G[v].pb(u);
	}
	for(int i=0;i<n;i++)w[1<<i]=read(),lg[1<<i]=i;
	for(int i=0;i<(1<<n);i++){
		ppc[i]=ppc[i>>1]+(i&1);
		if(ppc[i]==1)con[i]=1,ok[i]=1;
		else{
			w[i]=(w[i&(i-1)]+w[i&(-i)])%mod;
			lk[i]=lk[i&(i-1)]|lk[i&(-i)];
			con[i]=1;memset(vis,0,sizeof(vis));
			dfs(i,lg[i&(-i)]);
			for(int j=0;j<n;j++)
				if(i>>j&1)con[i]&=vis[j];
			ok[i]=con[i];
			for(int j=0;j<n;j++)
				if((i>>j&1)&&(ppc[lk[1<<j]&i]&1))ok[i]=0;
		}iw[i]=ksm(w[i],mod-2);ok[i]^=1;
	}
	if(n<=15){
		for(int i=1;i<(1<<n);i++){
			if(ok[i])f[i]=1;
			for(int j=(i-1)&i;j;j=(j-1)&i)
				if(ok[j])f[i]=(f[i]+1ll*f[i^j]*calc(j,i))%mod;
		}printf("%d\n",f[(1<<n)-1]);
	}else{
		int lim=(1<<n);
		dp[0][0]=1;fwt(dp[0],lim,1);
		for(int i=1;i<=n;i++){
			for(int j=0;j<(1<<n);j++)
				if(ppc[j]==i)a[i][j]=1ll*ok[j]*ksm(w[j],p)%mod;
			fwt(a[i],lim,1);
			for(int j=1;j<=i;j++)
				for(int k=0;k<(1<<n);k++)
					dp[i][k]=(dp[i][k]+1ll*dp[i-j][k]*a[j][k])%mod;
			fwt(dp[i],lim,0);
			for(int j=0;j<(1<<n);j++)
				if(ppc[j]!=i)dp[i][j]=0;
				else dp[i][j]=1ll*dp[i][j]*ksm(iw[j],p)%mod;
			fwt(dp[i],lim,1);
		}fwt(dp[n],lim,0);
		printf("%d\n",dp[n][lim-1]);
	}
    return 0;
}

fun fact: 有 \((-1)^{|a\cap b|+|a\cap c|}=(-1)^{a\cap (b\oplus c)}\)

posted @ 2022-03-30 07:17  syzf2222  阅读(180)  评论(0编辑  收藏  举报