快速沃尔什变换(FWT)学习笔记 + 洛谷P4717 [模板]

FWT求解的是一类问题:\( a[i] = \sum\limits_{j\bigoplus k=i}^{} b[j]*c[k] \)

其中,\( \bigoplus \) 可以是 or,and,xor

三种问题的解决思路都是对多项式 \( a \) 构造一个 \( a' \),令 \( a' = b' * c' \);

那么只需要把 \( b \) 变换成 \( b' \),\( c \) 变换成 \( c' \),然后乘出 \( a' \),再逆变换得到 \( a \);

下面问题就变成如何快速(logn)求 \( b \) 到 \( b' \) 的变换,这个变换就是 FWT;

始终要记住进行位运算的是位置(角标)而不是值;

一、or

构造 \( a'[i] = \sum\limits_{j|i=i}^{} a[j] \)

1.正变换

考虑把 \( a \) 分成前后两个部分 \( a0 \) 和 \( a1 \),先分别递归下去做好,得到 \( a0' \) 和 \( a1' \);

可以发现,\( a0' \) 和 \( a1' \) 的位置(角标)数字上唯一不同就是最高位是0或1;

但递归下去做的时候,\( a0' \) 和 \( a1' \) 的位置数字相当与去掉了最高位(因为折半了);

所以合并的时候,关键要考虑到最高位的0和1的不同:

(1) 对于 \( a' \) 的一个位置 \( i \) ,如果它在前半部分,那么它可以直接继承 \( a0'[i] \);

而 \( a1'[i]\) 由于实际上 \( i \) 还应该加上最高位的1,or 运算使它能贡献的位置最高位也是1,但 \( i \) 的最高位是0,所以不贡献给 \( a'[i] \) ;

(2)对于后半部分的 \( i \) ,\( a0'[i] \) 和 \( a1'[i] \) 都会对它产生贡献,因为两部分的位置数字都是 \( i \) 的子集;

所以可以得到:\( a' = \left ( a0' , a0'+a1' \right ) \)

递归的底层,只有一个元素的时候,\( a = a' \) ,于是我们可以递归做出正变换了;

当然,仿照 FFT 的写法即可,并不需要真的写递归函数,而且也不用蝴蝶变换;

2.逆变换

同样先考虑两个部分 \( a'0 \) 和 \( a'1 \) ,表示 \( a' \) 的前后部分;

已经做了 \( a' = \left ( a0' , a0'+a1' \right ) \)

现在要从 \( a' \) 拆出 \( a0' \) 和 \( a1' \)

那么 \( a0' = a'0 \)

而且 \( a1' = a'1 - a'0 \)

知道了 \( a0' \) 和 \( a0' \) ,就可以继续递归求解 \( a0 \) 和 \( a1 \),二者合起来就可以得到 \( a \)

递归的底层,只有一个元素的时候,\( a' = a \) ,于是我们可以递归作出逆变换了;

void fwt1(int *a,int tp)//a'=(a0',a0'+a1')  //a=(a0',a1'-a0')
{
  for(int mid=1;mid<lim;mid<<=1)
    for(int j=0,len=(mid<<1);j<lim;j+=len)
      for(int k=0;k<mid;k++)
      a[j+mid+k]=upt(a[j+mid+k]+tp*a[j+k]);
}
or

 

二、and

构造 \( a' = \sum\limits_{j \& i=i}^{} a[j] \)

1.正变换

和 or 同理,考虑最高位01的不同,后面继承本身,而前面要加上后面的贡献;

得到 \( a' = \left ( a0'+a1' , a1' \right ) \)

2.逆变换

同理,得到 

\( a0' = a'0 - a'1 \)

\( a1' = a'1 \)

void fwt2(int *a,int tp)//a'=(a0'+a1',a1')  //a=(a0'-a1',a1')
{
  for(int mid=1;mid<lim;mid<<=1)
    for(int j=0,len=(mid<<1);j<lim;j+=len)
      for(int k=0;k<mid;k++)
      a[j+k]=upt(a[j+k]+tp*a[j+mid+k]);
}
and

 

三、xor

设 \( d(i,j) \) 表示 \( i\&j \) 二进制表示中1的个数;

构造 \( a' = \sum\limits_{d(i,j)\%2==0}^{} a[j] - \sum\limits_{d(i,j)\%2==1}^{} a[j] \)

1.正变换

让我们三步走:

(1) \( a' = \left ( a0' + a1' , a0' - a1' \right ) \)

首先明确,\( a' \) 是 \( d(i\&j) \) 为偶数的 \( a[j] \) 求和,减去 \( d(i\&j) \) 为奇数的 \( a[j] \) 求和;

<1> 对于整体的一个位置 \( i \),它在前半部分

对于前半部分(折半)的相同位置 \( i' \),在前半部分的 \( j \) 中,\( d(i'\&j) \) 的奇偶性和 \( d(i\&j) \) 一样,所以继承答案;

对于后半部分(折半)的相同位置 \( i' \),在后半部分的 \( j \) 中,计算 \( i'\&j \) 时是没有考虑最高位的,所以它们的最高位上都是0,

而因为 \( i \) 的最高位是0,\( i\&j \) 的最高位同样是0,所以正好符合,答案可以加上;

也就是,\( a' = \left ( a0' + a1' , ... \right ) \)

<2> 对于整体的一个位置 \( i \),它在后半部分

对于前半部分(折半)的相同位置 \( i' \),在前半部分的 \( j \) 中,\( d(i'\&j) \) 的最高位都是0,

而因为 \( j \) 的最高位是0,\( i\&j \) 的最高位同样是0,所以正好符合,答案可以加上;

对于后半部分(折半)的相同位置 \( i' \),在后半部分的 \( j \) 中,计算 \( i'\&j \) 时是没有考虑最高位的,所以它们的最高位上都是0,

但 \( i\&j \) 的最高位是1,所以奇偶性都反了,答案加上的是负的;

这样,就得到 \( a' = \left ( a0' + a1' , a0' - a1' \right ) \)

 

(2) \( d(i\&k) \otimes d(j\&k) = d( (i \otimes j)\&k ) \)

因为是 \( \& \) ,我们就看 \( k \) 是1的那些位;

如果 \( d(i\&k) \) 是偶数,说明 \( i\&k \) 有偶数个1和 \( k \) 重合,奇数同理,\( j \) 同理;

<1> 当 \( d(i\&k) \) 和 \( d(j\&k) \) 奇偶性相同时

\( d(i\&k) + d(j\&k) \) 是偶数;

而 \( i \otimes j \) 同时消去 \( i \) 和 \( j \) 相同位置的1,不是 \( k \) 的1就算了,是 \( k \) 的1,消去的也是偶数;

所以 \( d( (i \otimes j)\&k ) \) 是偶数;

<2> 当 \( d(i\&k) \) 和 \( d(j\&k) \) 奇偶性不同时

\( d(i\&k) + d(j\&k) \) 是奇数;

而 \( i \otimes j \) 同时消去 \( i \) 和 \( j \) 相同位置的1,不是 \( k \) 的1就算了,是 \( k \) 的1,消去的是偶数;

所以 \( d( (i \otimes j)\&k ) \) 是奇数;

这样我们就证明了 \( d(i\&k) \otimes d(j\&k) = d( (i \otimes j)\&k ) \)

 

(3) 若 \( c[i] = \sum_{j \otimes k=i}^{} a[j]*b[k] \) ,有 \( c' = a' * b' \)

因为 \( c[i] = \sum_{j \otimes k=i}^{} a[j]*b[k] \)

又 \( c' = \sum_{d(i,j)\%2==0}^{} c[j] - \sum_{d(i,j)\%2==1}^{} c[j] \)

代入,得到 \( c'[i] = \sum_{d((j \otimes k)\&i)\%2==0}^{} a[j]*b[k] - \sum_{d((j \otimes k)\&i)\%2==1}^{} a[j]*b[k] \)

而 \( a'[i] * b'[i] = ( \sum_{d(i,j)\%2==0}^{} a[j] - \sum_{d(i,j)\%2==1}^{} a[j]) * ( \sum_{d(i,j)\%2==0}^{} b[j] - \sum_{d(i,j)\%2==1}^{} b[j]) \)

拆开再组合,并使用(2)得到的结论,就得到 \( a'[i] * b'[i] = ( \sum_{d((j \otimes k)\&i)\%2==0}^{} a[j]*b[k] - \sum_{d((j \otimes k)\&i))\%2==1}^{} a[j]*b[k]) \)

所以 \( c' = a' * b' \)

综上,我们仍然可以递归求 xor 的正变换,\( a' = \left ( a0' + a1' , a0' - a1' \right ) \)

 

2.逆变换

根据正变换就可以知道咯:

\( a0' = (a'0 + a'1) / 2 \)

\( a1' = (a'0 - a'1) / 2 \)

void fwt3(int *a,int tp)//a'=(a0'+a1',a0'-a1')  //a=((a0'+a1')/2,(a0'-a1')/2)
{
  for(int mid=1;mid<lim;mid<<=1)
    for(int j=0,len=(mid<<1);j<lim;j+=len)
      for(int k=0;k<mid;k++)
    {
      int x=a[j+k],y=a[j+mid+k];
      a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y);
      if(tp==-1)a[j+k]=(ll)a[j+k]*inv%mod,a[j+mid+k]=(ll)a[j+mid+k]*inv%mod;
    }
}
xor

 

看例题:https://www.luogu.org/problemnew/show/P4717

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=(1<<17),mod=998244353;
int n,a[xn],b[xn],f[xn],g[xn],lim,inv;
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='0')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return f?ret:-ret;
}
ll pw(ll a,int b)
{
  ll ret=1;
  for(;b;b>>=1,a=(a*a)%mod)
    if(b&1)ret=(ret*a)%mod;
  return ret;
}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
void fwt1(int *a,int tp)//a'=(a0',a0'+a1')  //a=(a0',a1'-a0')
{
  for(int mid=1;mid<lim;mid<<=1)
    for(int j=0,len=(mid<<1);j<lim;j+=len)
      for(int k=0;k<mid;k++)
      a[j+mid+k]=upt(a[j+mid+k]+tp*a[j+k]);
}
void fwt2(int *a,int tp)//a'=(a0'+a1',a1')  //a=(a0'-a1',a1')
{
  for(int mid=1;mid<lim;mid<<=1)
    for(int j=0,len=(mid<<1);j<lim;j+=len)
      for(int k=0;k<mid;k++)
      a[j+k]=upt(a[j+k]+tp*a[j+mid+k]);
}
void fwt3(int *a,int tp)//a'=(a0'+a1',a0'-a1')  //a=((a0'+a1')/2,(a0'-a1')/2)
{
  for(int mid=1;mid<lim;mid<<=1)
    for(int j=0,len=(mid<<1);j<lim;j+=len)
      for(int k=0;k<mid;k++)
    {
      int x=a[j+k],y=a[j+mid+k];
      a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y);
      if(tp==-1)a[j+k]=(ll)a[j+k]*inv%mod,a[j+mid+k]=(ll)a[j+mid+k]*inv%mod;
    }
}
int main()
{
  n=rd(); lim=(1<<n); inv=pw(2,mod-2);
  for(int i=0;i<lim;i++)a[i]=f[i]=rd();
  for(int i=0;i<lim;i++)b[i]=g[i]=rd();
  fwt1(f,1); fwt1(g,1);
  for(int i=0;i<lim;i++)f[i]=(ll)f[i]*g[i]%mod;
  fwt1(f,-1);
  for(int i=0;i<lim;i++)printf("%d ",f[i]); puts("");

  for(int i=0;i<lim;i++)f[i]=a[i],g[i]=b[i];
  fwt2(f,1); fwt2(g,1);
  for(int i=0;i<lim;i++)f[i]=(ll)f[i]*g[i]%mod;
  fwt2(f,-1);
  for(int i=0;i<lim;i++)printf("%d ",f[i]); puts("");

  for(int i=0;i<lim;i++)f[i]=a[i],g[i]=b[i];
  fwt3(f,1); fwt3(g,1);
  for(int i=0;i<lim;i++)f[i]=(ll)f[i]*g[i]%mod;
  fwt3(f,-1);
  for(int i=0;i<lim;i++)printf("%d ",f[i]); puts("");
  return 0;
}

 

参考博客:https://www.cnblogs.com/ACMLCZH/p/8022502.html

https://blog.csdn.net/neither_nor/article/details/60335099

posted @ 2018-11-29 16:34  Zinn  阅读(374)  评论(0编辑  收藏  举报