FWT 学习笔记
更新于 2023.2.20,之前写的很粗糙,补一下。
讲一下原理,类似 DFT 的,我们也希望构造一个 FWT 数组使得卷积可以写作变幻、点积、变幻。
对于 or 卷积,构造 \(FWT(F_i)=\sum\limits_{j\subseteq i} F_j\),对于 and 卷积构造 \(FWT(F_i)=\sum\limits_{i\subseteq j}F_j\)。
可以递归求,以 or 卷积为例:
可以发现我们循环实现的就是这个东西,我们只需要构造 \(FWT\) 的方法,然后对两部做变幻。
异或卷积比较特殊,有 \(FWT(F_i)=\sum\limits_{j}(-1)^{|i\cap j|} F_j\),变幻是容易构造的。
类似 FFT 地,FWT 也有这样的代码实现:
inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
for(int i=1;i<len;i<<=1)
for(int j=0;j<len;j+=i+i)
for(int k=j;k<j+i;k++){
int u=f[k],v=f[k+i];
f[k]=add(u+v);f[k+i]=sub(u-v);
if(!flg){
f[k]=1ll*f[k]*iv2%mod;
f[k+i]=1ll*f[k+i]*iv2%mod;
}
}
}
相比于 FMT,FWT 有一半的常数并且可以执行异或操作。
uoj310. 黎明前的巧克力
题意:给定集合 \(S\),计数 \(S1,S2,S3\) 互不相交,并为 \(S\),且 \(\oplus S3=0\) 的方案数。
题解:显然有 \(dp\):\(f_{i,j}=f_{i-1,j},f_{i,j}=2f_{i-1,j\oplus{a_i}}\),令 \(F_{i,0}=1,F_{i,a_i}=2\),则答案为 \(FWT(\prod F_i)\)
考察 FWT 的性质,首先 FWT 是线性变换,可以把 \(0\) 和 \(a_i\) 拆开算。
类似 FFT 地,FWT 相当于给每一维做 FFT,于是 \(FWT(F)(n)=\sum\limits_{k=0}^{2^l-1} F_k\prod\limits_{j=0}^l (-1)^{n_jk_j}=\sum\limits_{k=0}^{2^l-1} F_k(-1)^{n\&k}\)
于是显然 \(F_0=1\) 的 FWT 就是每一位都是 \(1\),\(F_{a_i}=2\) 的 FWT 就是一堆 \(2\) 和 \(-2\)
于是我们只用知道最终有多少个 \(-1\) 和 \(3\) 就好了,显然我们知道和就能解出来,而和就是整体的 FWT
#include<bits/stdc++.h>
using namespace std;
#define inf 1e9
const int maxn=2e5+10;
const int mod=998244353;
const int iv2=(mod+1)/2;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
for(int i=1;i<len;i<<=1)
for(int j=0;j<len;j+=i+i)
for(int k=j;k<j+i;k++){
int u=f[k],v=f[k+i];
f[k]=u+v;f[k+i]=u-v;
if(!flg){
f[k]=1ll*f[k]*iv2%mod;
f[k+i]=1ll*f[k+i]*iv2%mod;
f[k]=sub(f[k]);
f[k+i]=sub(f[k+i]);
}
}
}
const int len=1<<20;
int n,m,a[len+5];
inline int ksm(int x,int y){
int res=1;
while(y){
if(y&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;y>>=1;
}return res;
}
int main(){
n=read();
for(int i=1,x;i<=n;i++)
x=read(),a[x]++;
fwt(a,len,1);
for(int i=0;i<len;i++)
a[i]=1ll*ksm(3,(a[i]+n)/2)*ksm(mod-1,(n-a[i])/2)%mod;
fwt(a,len,0);
printf("%d\n",sub(a[0]-1));
return 0;
}
CF662C Binary Table
首先很容易想到 \(O(2^n)\) 枚举翻转行集合,然后 \(O(m)\) 判断 \(0/1\) 最小值。
令 \(cnt_i\) 表示数值个数,\(a_i\) 表示状态 \(i\) 的 \(0/1\) 最小值,则翻转集合为 \(S\) 的答案为 \(F_S=\sum cnt_i a_{i\oplus S}\),显然的 FWT 形式。
#include<bits/stdc++.h>
using namespace std;
#define inf 1e16
const int maxn=2e5+10;
const int mod=1e9+7;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
#define ll long long
const int N=1<<20;
int n,m,lim;
ll a[N+5],b[N+5];
inline void fwt(ll *f,int len,int flg){
for(int i=2;i<=len;i<<=1)
for(int j=0,p=i/2;j<len;j+=i)
for(int k=j;k<j+p;k++){
ll u=f[k],v=f[k+p];
f[k]=u+v;f[k+p]=u-v;
if(!flg)f[k]/=2,f[k+p]/=2;
}
}
int main(){
n=read(),m=read(),lim=1<<n;
for(int i=0;i<n;i++)
for(int j=1,x;j<=m;j++)
scanf("%1d",&x),a[j]=a[j]+(x<<i);
for(int i=1;i<=m;i++)b[a[i]]++;
for(int i=1;i<lim;i++)a[i]=a[i&(i-1)]+1;
for(int i=1;i<lim;i++)a[i]=min(a[i],n-a[i]);
fwt(a,lim,1),fwt(b,lim,1);
for(int i=0;i<lim;i++)a[i]*=b[i];
fwt(a,lim,0);ll Mn=inf;
for(int i=0;i<lim;i++)Mn=min(Mn,a[i]);
printf("%lld\n",Mn);
return 0;
}
CF850E Random Elections
太着急看题解了,下次可以固定一个时间强制思考。
三人等价,不妨假设 \(A\) 赢了 \(B,C\),将 \(AB,AC,BC\) 也用 \(0/1\) 表示,则有 \(AB,AC\) 为 \(00\) 或 \(11\) 时有一种方案,否则两种。
则我们将 \(F\) 跟自己卷一遍,然后每个位置乘上二的 \(0\) 个数次方加入答案即可。
#include<bits/stdc++.h>
using namespace std;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
const int maxn=2e5+10;
const int mod=1e9+7;
const int iv2=(mod+1)/2;
inline int ksm(int x,int y){
int res=1;
while(y){
if(y&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;y>>=1;
}return res;
}
inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
for(int i=2;i<=len;i<<=1)
for(int j=0,p=i/2;j<len;j+=i)
for(int k=j;k<j+p;k++){
int u=f[k],v=f[k+p];
f[k]=add(u+v);
f[k+p]=sub(u-v);
if(!flg){
f[k]=1ll*f[k]*iv2%mod;
f[k+p]=1ll*f[k+p]*iv2%mod;
}
}
}
const int N=1<<20;
int f[N+5],lim,ans,n;
int main(){
n=read();lim=1<<n;
for(int i=0;i<lim;i++)
scanf("%1d",&f[i]);
fwt(f,lim,1);
for(int i=0;i<lim;i++)
f[i]=1ll*f[i]*f[i]%mod;
fwt(f,lim,0);
for(int i=0;i<lim;i++)
ans=(ans+1ll*f[i]*(1<<n-__builtin_popcount(i)))%mod;
printf("%d\n",3ll*ans%mod);
return 0;
}
WC2018 州区划分
首先可以得到一个简单的 \(3^n\) dp,轻松拿到 \(n\leqslant 15\) 的部分分。
发现瓶颈在于解决形如 \(f_i=\sum\limits_{j|k=i,j\&k=0} f_jg_k\),直接上子集卷积即可。
#include<bits/stdc++.h>
using namespace std;
#define inf 1e9
const int maxn=2e5+10;
const int mod=998244353;
const int iv2=(mod+1)/2;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
inline int add(int x){return x>=mod?x-mod:x;}
inline int sub(int x){return x<0?x+mod:x;}
inline void fwt(int *f,int len,int flg){
for(int i=2;i<=len;i<<=1)
for(int j=0,l=i/2;j<len;j+=i)
for(int k=j;k<j+l;k++){
int u=f[k],v=f[k+l];
if(flg)f[k+l]=add(u+v);
else f[k+l]=sub(v-u);
}
}
const int N=21;
int n,m,p,lk[1<<N],w[1<<N],ppc[1<<N],con[1<<N],ok[1<<N],f[1<<N],iw[1<<N],dp[N+1][1<<N],a[N+1][1<<N];
inline int ksm(int x,int y){
int res=1;
while(y){
if(y&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;y>>=1;
}return res;
}
inline int calc(int j,int i){
int v=1ll*w[j]*iw[i]%mod;
return ksm(v,p);
}
vector<int>G[N];
#define pb push_back
int vis[N],lg[1<<N];
inline void dfs(int sta,int x){
if(vis[x])return;vis[x]=1;
for(auto t:G[x])
if(sta>>t&1)dfs(sta,t);
}
int main(){
freopen("P4221.in","r",stdin);
freopen("P4221.out","w",stdout);
n=read(),m=read(),p=read();
for(int i=1,u,v;i<=m;i++){
u=read()-1,v=read()-1;
lk[1<<u]|=(1<<v);lk[1<<v]|=(1<<u);
G[u].pb(v),G[v].pb(u);
}
for(int i=0;i<n;i++)w[1<<i]=read(),lg[1<<i]=i;
for(int i=0;i<(1<<n);i++){
ppc[i]=ppc[i>>1]+(i&1);
if(ppc[i]==1)con[i]=1,ok[i]=1;
else{
w[i]=(w[i&(i-1)]+w[i&(-i)])%mod;
lk[i]=lk[i&(i-1)]|lk[i&(-i)];
con[i]=1;memset(vis,0,sizeof(vis));
dfs(i,lg[i&(-i)]);
for(int j=0;j<n;j++)
if(i>>j&1)con[i]&=vis[j];
ok[i]=con[i];
for(int j=0;j<n;j++)
if((i>>j&1)&&(ppc[lk[1<<j]&i]&1))ok[i]=0;
}iw[i]=ksm(w[i],mod-2);ok[i]^=1;
}
if(n<=15){
for(int i=1;i<(1<<n);i++){
if(ok[i])f[i]=1;
for(int j=(i-1)&i;j;j=(j-1)&i)
if(ok[j])f[i]=(f[i]+1ll*f[i^j]*calc(j,i))%mod;
}printf("%d\n",f[(1<<n)-1]);
}else{
int lim=(1<<n);
dp[0][0]=1;fwt(dp[0],lim,1);
for(int i=1;i<=n;i++){
for(int j=0;j<(1<<n);j++)
if(ppc[j]==i)a[i][j]=1ll*ok[j]*ksm(w[j],p)%mod;
fwt(a[i],lim,1);
for(int j=1;j<=i;j++)
for(int k=0;k<(1<<n);k++)
dp[i][k]=(dp[i][k]+1ll*dp[i-j][k]*a[j][k])%mod;
fwt(dp[i],lim,0);
for(int j=0;j<(1<<n);j++)
if(ppc[j]!=i)dp[i][j]=0;
else dp[i][j]=1ll*dp[i][j]*ksm(iw[j],p)%mod;
fwt(dp[i],lim,1);
}fwt(dp[n],lim,0);
printf("%d\n",dp[n][lim-1]);
}
return 0;
}
fun fact: 有 \((-1)^{|a\cap b|+|a\cap c|}=(-1)^{a\cap (b\oplus c)}\)