FWT 等总结 题解
FWT可以解决位运算卷积问题。
即\(h(i)=\sum\limits_{j⊕k=i} f(j)*g(k)\),其中“⊕”表示位运算。
与卷积:
定义\(f\)到\(F\)的变换:\(F(i)=\sum\limits_{j\&i==i}^{ }f(j)\)。
这样,若\(h(i)=\sum\limits_{j and k=i} f(j)*g(k)\),则\(H(i)=F(i)*G(i)\)。
变换方法:就是按照长度为\(2^i\)分段,把每段的后半部分加到前半部分(1对0有额外贡献)。
逆变换就是减回去。时间复杂度:\(O(nlogn)\)。
代码:
void fwtand(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+k]=(sz[j+k]+sz[j+(i>>1)+k])%md;
}
}
}
void ifwtand(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+k]=(sz[j+k]-sz[j+(i>>1)+k]+md)%md;
}
}
}
或卷积:
与“与卷积”类似。
定义\(f\)到\(F\)的变换:\(F(i)=\sum\limits_{j|i==i}^{ }f(j)\)。
这样,若\(h(i)=\sum\limits_{j or k=i} f(j)*g(k)\),则\(H(i)=F(i)*G(i)\)。
变换方法:就是按照长度为\(2^i\)分段,把每段的前半部分加到后半部分(0对1有额外贡献)。
逆变换就是减回去。时间复杂度:\(O(nlogn)\)。
代码:
void fwtor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]+sz[j+k])%md;
}
}
}
void ifwtor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]-sz[j+k]+md)%md;
}
}
}
这两个其实是高维前/后缀和
异或卷积:
这个比较常用。
定义\(f\)到\(F\)的变换:\(F(i)=\sum\limits_{j=0}^{2^n-1}(-1)^{bit(j and i)}f(j)\)。
这样,若\(h(i)=\sum\limits_{j xor k=i} f(j)*g(k)\),则\(H(i)=F(i)*G(i)\)。
变换方法:就是按照长度为\(2^i\)分段,把每段的前半部分变为前半部分加后半部分,
后半部分变为前半部分减后半部分。
逆变换就是相当于已知\(a+b=x,a-b=y\),则\(a=(x+y)/2,b=(x-y)/2\)。
就是正变换再除以2。
时间复杂度:\(O(nlogn)\)。
代码:
void fwtxor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
{
int a=sz[j+k],b=sz[j+(i>>1)+k];
sz[j+k]=(a+b)%md;
sz[j+(i>>1)+k]=(a-b+md)%md;
}
}
}
}
void ifwtxor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
{
int a=sz[j+k],b=sz[j+(i>>1)+k];
sz[j+k]=1ll*(a+b)*inv%md;
sz[j+(i>>1)+k]=1ll*(a-b+md)*inv%md;
}
}
}
}
FST:子集卷积
即\(h(i)=\sum\limits_{j or k=i且j and k=0} f(j)*g(k)\)。
比或卷积多了一个限制。
我们发现,设\(s(i)\)表示\(i\)的二进制表示中1的个数,那么如果\(i\|j=k,i\&j=0\),则\(s(i)+s(j)=s(k)\)。
利用这个性质,我们可以加一维表示\(s\),在\(F*G\)时考虑\(s\)的限制。
时间复杂度:\(O(nlog^2n)\)。
代码:
for(int i=0;i<len;i++)
{
for(int j=0;j<17;j++)
{
if(i&(1<<j))
sl[i]+=1;
}
}
for(int i=0;i<len;i++)
a[sl[i]][i]=sz[i];
for(int i=0;i<18;i++)
fwtor(a[i],len);
for(int i=0;i<18;i++)
{
for(int j=0;i+j<18;j++)
{
for(int k=0;k<len;k++)
h1[i+j][k]=(h1[i+j][k]+1ll*a[i][k]*a[j][k])%md;
}
}
for(int i=0;i<18;i++)
ifwtor(h1[i],len);
for(int i=0;i<len;i++)
ab[i]=h1[sl[i]][i];
例题:
CF914G
题意:
给你一个长度为\(n\)的数组\(s\).定义五元组\((a,b,c,d,e)\)是合法的当且仅当:
①.\(1\le a,b,c,d,e\le n\)
②.\((s_a|s_b)\&s_c\&(s_d\)^\(s_e)=2^i,i\in Z\)
③.\(s_a\&s_b=0\)
对于所有合法的五元组\((a,b,c,d,e)\)
求\(\sum f(s_a|s_b)*f(s_c)*f(s_d\)^\(s_e)\mod 10^9+7\)
\(f_0=0,f_1=1,f_i=f_{i-1}+f_{i-2}\)
\(1\le n\le10^6,0\le s_i\lt2^{17}\)
模板题。
先考虑\((s_a|s_b)\),发现是FST卷积。
再考虑\((s_d\)^\(s_e)\),是异或卷积。
然后就是与卷积了。
代码:
#include <stdio.h>
#define md 1000000007
#define inv 500000004
#define len 131072
int sz[132000],sl[132000];
void fwtor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]+sz[j+k])%md;
}
}
}
void ifwtor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+(i>>1)+k]=(sz[j+(i>>1)+k]-sz[j+k]+md)%md;
}
}
}
void fwtand(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+k]=(sz[j+k]+sz[j+(i>>1)+k])%md;
}
}
}
void ifwtand(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
sz[j+k]=(sz[j+k]-sz[j+(i>>1)+k]+md)%md;
}
}
}
void fwtxor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
{
int a=sz[j+k],b=sz[j+(i>>1)+k];
sz[j+k]=(a+b)%md;
sz[j+(i>>1)+k]=(a-b+md)%md;
}
}
}
}
void ifwtxor(int sz[132000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
for(int j=0;j<n;j+=i)
{
for(int k=0;k<(i>>1);k++)
{
int a=sz[j+k],b=sz[j+(i>>1)+k];
sz[j+k]=1ll*(a+b)*inv%md;
sz[j+(i>>1)+k]=1ll*(a-b+md)*inv%md;
}
}
}
}
int a[18][132000],h1[18][132000],fib[132000],h2[132000];
int ab[132000],de[132000],ans[132000];
int main()
{
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)
{
int a;
scanf("%d",&a);
sz[a]+=1;
}
for(int i=0;i<len;i++)
{
for(int j=0;j<17;j++)
{
if(i&(1<<j))
sl[i]+=1;
}
}
fib[0]=0;fib[1]=1;
for(int i=2;i<len;i++)
fib[i]=(fib[i-1]+fib[i-2])%md;
for(int i=0;i<len;i++)
a[sl[i]][i]=sz[i];
for(int i=0;i<18;i++)
fwtor(a[i],len);
for(int i=0;i<18;i++)
{
for(int j=0;i+j<18;j++)
{
for(int k=0;k<len;k++)
h1[i+j][k]=(h1[i+j][k]+1ll*a[i][k]*a[j][k])%md;
}
}
for(int i=0;i<18;i++)
ifwtor(h1[i],len);
for(int i=0;i<len;i++)
ab[i]=h1[sl[i]][i];
for(int i=0;i<len;i++)
de[i]=sz[i];
fwtxor(de,len);
for(int i=0;i<len;i++)
de[i]=1ll*de[i]*de[i]%md;
ifwtxor(de,len);
for(int i=0;i<len;i++)
{
ab[i]=1ll*ab[i]*fib[i]%md;
sz[i]=1ll*sz[i]*fib[i]%md;
de[i]=1ll*de[i]*fib[i]%md;
}
fwtand(ab,len);
fwtand(sz,len);
fwtand(de,len);
for(int i=0;i<len;i++)
ans[i]=1ll*ab[i]*sz[i]%md*de[i]%md;
ifwtand(ans,len);
int jg=0;
for(int i=1;i<=len;i=(i<<1))
jg=(jg+ans[i])%md;
printf("%d",jg);
return 0;
}
uoj310【UNR #2】黎明前的巧克力
题意:有一个集合,选出两个不相交的子集,使其异或和相等,问方案数。
考虑dp:设\(dp(i,j)\)表示考虑到i,两人异或为j的方案数。
则\(dp(i,j)=dp(i-1,j)+2*dp(i-1,j\)^\(a(i))\)。
考虑FWT:对每个i构造A,使\(A(0)=1,A(a(i))=2\)。
对每个A做FWT,乘起来后再IFWT。但是复杂度太高。
根据公式,可以发现,FWT(A)的每位只能是3或-1。
那么,只要知道FWT后,A的每个对应位置之和,就能解出3和-1的数量了,之后快速幂即可。
根据加法的运算律,可以得知若干个长度相等的序列FWT后对应位置求和,等于先求和,再FWT。
所以求和后,FWT,之后快速幂算出每个位置的值,再IFWT,最后位置0的值减1就是答案。
时间复杂度:\(O(mlogm)\)。
思路非常巧妙。
代码:
#include <stdio.h>
#define md 998244353
#define inv 499122177
#define len 1048576
void fwt(int sz[1050000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
int t=(i>>1);
for(int j=0;j<n;j+=i)
{
for(int k=j;k<j+t;k++)
{
int a=sz[k+t];
sz[k+t]=sz[k]-a;
sz[k]=sz[k]+a;
}
}
}
}
void ifwt(int sz[1050000],int n)
{
for(int i=2;i<=n;i=(i<<1))
{
int t=(i>>1);
for(int j=0;j<n;j+=i)
{
for(int k=j;k<j+t;k++)
{
int a=sz[k+t];
sz[k+t]=1ll*(sz[k]-a+md)*inv%md;
sz[k]=1ll*(sz[k]+a)*inv%md;
}
}
}
}
int sz[1050010],mi[1050000],m3[1050010],sl[1050000];
int main()
{
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%d",&sz[i]);
sl[0]=n;
for(int i=0;i<n;i++)
sl[sz[i]]+=2;
fwt(sl,len);
for(int i=0;i<len;i++)
mi[i]=(n+sl[i])/4;
m3[0]=1;
for(int i=1;i<=n;i++)
m3[i]=3ll*m3[i-1]%md;
for(int i=0;i<len;i++)
{
if((n-mi[i])%2==0)
mi[i]=m3[mi[i]];
else
mi[i]=md-m3[mi[i]];
}
ifwt(mi,len);
printf("%d",(mi[0]-1+md)%md);
return 0;
}
扩展
这个技巧还可以扩展:
就是说当权值有k个时,我们先将其中一个权值变为0,这样总共可能的贡献有\(2^{k-1}\)种。(每个系数是1或-1),0的贡献一定是1。
为了把这\(2^{k-1}\)种的数量分别求出来,我们需要找\(2^{k-1}\)个等式。
可以枚举剩余k-1个元素的子集,只将其异或的位置+1,做FWT。
这样,每个位置,会得到\(2^{k-1}\)个数。将这些数做FWT后,就可以的得到\(2^{k-1}\)种可能分别的数量,快速幂即可。
证明略。
(核心)代码:
int xo=0,mi=1;
for(int i=0;i<n;i++)
{
for(int j=0;j<k;j++)
scanf("%d",&p[i][j]);
xo^=p[i][0];
for(int j=1;j<k;j++)
p[i][j]^=p[i][0];
}
ans[xo]=1;
fwtxor(ans,(1<<m));
for(int s=0;s<(1<<(k-1));s++)
{
for(int i=0;i<n;i++)
{
int z=0;
for(int j=1;j<k;j++)
{
if(s&(1<<(j-1)))
z^=p[i][j];
}
sz[z]+=1;
}
fwtxor(sz,(1<<m));
for(int i=0;i<(1<<m);i++)
{
nf[(i<<(k-1))|s]=sz[i];
sz[i]=0;
}
}
for(int i=0;i<(1<<m);i++)
{
for(int s=0;s<(1<<(k-1));s++)
zz[s]=nf[(i<<(k-1))|s];
ifwtxor(zz,1<<(k-1));
for(int s=0;s<(1<<(k-1));s++)
{
int he=sl[0];
for(int j=1;j<k;j++)
{
if(s&(1<<(j-1)))
he=(he-sl[j]+md)%md;
else
he=(he+sl[j])%md;
}
ans[i]=1ll*ans[i]*ksm(he,zz[s])%md;
}
}
ifwtxor(ans,(1<<m));
CF662C Binary Table
题意:
有一个 n 行 m 列的表格,每个元素都是 0/1 ,每次操作可以选择一行或一列,把 0/1 翻转,即把 0 换为 1 ,把 1 换为 0 。请问经过若干次操作后,表格中最少有多少个 1。\((1\leq n \leq 20,1\leq m \leq 10^5)\)。
首先,我们可以枚举行的交换,共\(2^n\)种。
然后,对每一列,考虑它是否交换。复杂度为\(O(nm2^n)\)。
考虑优化:
首先,我们发现,如果记翻转为1,那么翻转就是异或。
记B数组表示状态压缩后的每列的出现次数。
记A数组表示一列为这个状态的1的最少个数。
那么,设\(C_i=A_{ixorj}*B_j\)之和,那么C的最小值就是答案。
反一下,将A和B做异或卷积,即可得到C。时间复杂度\(O(nm+n2^n)\)。
代码:
#include <stdio.h>
#define ll long long
void fwt(ll sz[1048576],int n)
{
for(int h=2;h<=n;h=(h<<1))
{
for(int i=0;i<n;i+=h)
{
for(int j=0;j<(h>>1);j++)
{
ll a=sz[i+j],b=sz[i+j+(h>>1)];
sz[i+j]=a+b;
sz[i+j+(h>>1)]=a-b;
}
}
}
}
void ifwt(ll sz[1048576],int n)
{
for(int h=2;h<=n;h=(h<<1))
{
for(int i=0;i<n;i+=h)
{
for(int j=0;j<(h>>1);j++)
{
ll a=sz[i+j],b=sz[i+j+(h>>1)];
sz[i+j]=(a+b)/2;
sz[i+j+(h>>1)]=(a-b)/2;
}
}
}
}
int sz[20][100005];char zf[100005];
ll sa[1048576],sb[1048576];
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++)
{
scanf("%s",zf);
for(int j=0;j<m;j++)
sz[i][j]=zf[j]-'0';
}
for(int i=0;i<(1<<n);i++)
{
int s=0;
for(int j=0;j<n;j++)
{
if(i&(1<<j))
s+=1;
}
sa[i]=n-s;
if(s<sa[i])sa[i]=s;
}
for(int i=0;i<m;i++)
{
int s=0;
for(int j=0;j<n;j++)
{
if(sz[j][i])
s|=(1<<j);
}
sb[s]+=1;
}
fwt(sa,1<<n);fwt(sb,1<<n);
for(int i=0;i<(1<<n);i++)
sa[i]*=sb[i];
ifwt(sa,1<<n);
int ans=99999999;
for(int i=0;i<(1<<n);i++)
{
if(sa[i]<ans)
ans=sa[i];
}
printf("%d",ans);
return 0;
}