快速沃尔什变换(FWT)笔记

开头Orz hy,Orz yrx

部分转载自hy的博客

快速沃尔什变换,可以快速计算两个多项式的位运算卷积(即and,or和xor)

问题模型如下:

给出两个多项式$A(x)$,$B(x)$,求$C(x)$满足$C[i]=\sum\limits_{j⊗k=i}A[j]*B[k]$.

约定记号

$⊗$表示某种位运算(and,or和xor中的一种),若$a$,$b$是两个整数,则$a⊗b$表示对这两个数按位进行位运算;若$A$,$B$是两个多项式,则$A⊗B$表示对这两个多项式做如上卷积;两个多项式的点积用$·$表示。

FWT

感觉这个算法就是瞎凑出来的(大佬轻喷)

考虑对$A$和$B$做某种变换(类似FFT),使得变换之后对应位相乘之后逆运算就可以得到卷积$C(x)$。

那么这种变换$F(A)$(其中$A$是一个多项式)需要满足:

$F(A)·F(B)=F(A⊗B)$

$F(k\ast A)=k\ast F(A)$

$F(A+B)=F(A)+F(B)$

那就瞎凑呗,考虑用类似FFT的分治思想来解决,把多项式$A$的下标按照二进制最高位分类,最高位为0的记为$A_0$,为1的记为$A_1$,则$A=(A_0,A_1)$。

继续凑,设$F(A)=(k_{1}A_{0}+k_{2}A_{1},k_{3}A_{0}+k_{4}A_{1})$,那么要做的就是求出这四个常数。

不难发现$(k_{1},k_{2})$与$(k_{3},k_{4})$并没有本质上的区别,即求出了前半部分的多个解取其中两个代入即可。

将结果写作$(C_{0},C_{1})$,看回之前变换的定义,这里分类讨论:

对于$and$:

因为

$0⊗0=0$,$1⊗0=0$,$0⊗1=0$,$1⊗1=1$

所以

$(A_{0},A_{1})⊗(B_{0},B_{1})=(A_{0}⊗B_{0}+A_{0}⊗B_{1}+A_{1}⊗B_{0},A_{1}⊗B_{1})$

用两种方法表示$C$,可得

$(k_{1}A_{0}+k_{2}A_{1})·(k_{1}B_{0}+k_{2}B_{1})$

$=k_{1}(A_{0}⊗B_{0}+A_{0}⊗B_{1}+A_{1}⊗B_{0})+k_{2}(A_{1}⊗B_{1})$

拆括号得:

$k_{1}^{2}(A_{0}⊗B_{0})+k_{1}k_{2}(A_{0}⊗B_{1})+k_{1}k_{2}(A_{1}⊗B_{0})+k_{2}^{2}(A_{1}⊗B_{1})$

$=k_{1}(A_{0}⊗B_{0})+k_{1}(A_{0}⊗B_{1})+k_{1}(A_{1}⊗B_{0})+k_{2}(A_{1}⊗B_{1})$

则有:

$\begin{cases}
k_{1}=k_{1}^{2} \\
k_{1}=k_{1}k_{2} \\
k_{2}=k_{2}^{2} 
\end{cases}$

解得$\begin{cases} k_{1}=0 \\k_{2}=0\end{cases}$或$\begin{cases} k_{1}=1 \\k_{2}=0\end{cases}$或$\begin{cases} k_{1}=1 \\k_{2}=1\end{cases}$

考虑到要可以逆变换,所以解不能选两个相同的或者两个零(类似于求逆矩阵),因此这里只能选$(0,1)$和$(1,1)$两组解

令$(k_{1},k_{2})=(1,1)$,$(k_{3},k_{4})=(0,1)$

把系数写成矩阵,那么

$\begin{bmatrix} k_1 & k_2 \\ k_3 & k_4 \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 0 & 1 \end{bmatrix}$

把矩阵求逆,就可以得到逆变换的系数:

$\begin{bmatrix} 1 & -1 \\ 0 & 1 \end{bmatrix}$

对于$or$:$(A_{0},A_{1})⊗(B_{0},B_{1})=(A_{0}⊗B_{0},A_{0}⊗B_{1}+A_{1}⊗B_{0}+A_{1}⊗B_{1})$

正变换:$\begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}$

逆变换:$\begin{bmatrix} 0 & 1 \\ 1 & -1 \end{bmatrix}$

对于$xor$:$(A_{0},A_{1})⊗(B_{0},B_{1})=(A_{0}⊗B_{0}+A_{1}⊗B_{1},A_{0}⊗B_{1}+A_{1}⊗B_{0})$

正变换:$\begin{bmatrix} 1 & 1 \\ 1 & -1 \end{bmatrix}$

逆变换:$\begin{bmatrix} \frac{1}{2} & \frac{1}{2} \\ \frac{1}{2} & -\frac{1}{2} \end{bmatrix}$

代码(洛谷P4717):

 1 #include<iostream>
 2 #include<cstring>
 3 #include<cstdio>
 4 #include<cmath>
 5 #define OR 0
 6 #define AND 1
 7 #define XOR 2
 8 using namespace std;
 9 typedef long long ll;
10 const int mod=998244353;
11 int inv2,bit,bitnum,n,m,a[1000001],b[1000001],a1[1000001],a2[1000001],a3[1000001],b1[1000001],b2[1000001],b3[1000001];
12 int fastpow(int x,int y){
13     int ret=1;
14     for(;y;y>>=1,x=(ll)x*x%mod){
15         if(y&1)ret=(ll)ret*x%mod;
16     }
17     return ret;
18 }
19 void fwt(int s[],int n,int ty,int op){
20     for(int i=2;i<=n;i<<=1){
21         for(int j=0;j<n;j+=i){
22             for(int k=0;k<i/2;k++){
23                 int x=s[j+k],y=s[j+k+(i>>1)];//A0,A1
24                 if(op==1){
25                     if(ty==OR){
26                         s[j+k+(i>>1)]=(x+y)%mod;
27                     }else if(ty==AND){
28                         s[j+k]=(x+y)%mod;
29                     }else{
30                         s[j+k+(i>>1)]=(x+mod-y)%mod;
31                         s[j+k]=(x+y)%mod;
32                     }
33                 }else{
34                     if(ty==OR){
35                         s[j+k+(i>>1)]=((ll)y-x+mod)%mod;
36                     }else if(ty==AND){
37                         s[j+k]=(ll)((ll)x+mod-y)%mod;
38                     }else{
39                         s[j+k+(i>>1)]=(ll)((ll)x+mod-y)*inv2%mod;
40                         s[j+k]=(ll)(x+y)*inv2%mod;
41                     }
42                 }
43             }
44         }
45     }
46 }
47 int main(){
48     scanf("%d",&bitnum);
49     bit=(1<<bitnum);
50     inv2=fastpow(2,mod-2);
51     for(int i=0;i<bit;i++)scanf("%d",&a[i]),a1[i]=a[i],a2[i]=a[i],a3[i]=a[i];
52     for(int i=0;i<bit;i++)scanf("%d",&b[i]),b1[i]=b[i],b2[i]=b[i],b3[i]=b[i];
53     fwt(a1,bit,0,1);
54     fwt(a2,bit,1,1);
55     fwt(a3,bit,2,1);
56     fwt(b1,bit,0,1);
57     fwt(b2,bit,1,1);
58     fwt(b3,bit,2,1);
59     for(int i=0;i<bit;i++){
60         a1[i]=(ll)a1[i]*b1[i]%mod;
61         a2[i]=(ll)a2[i]*b2[i]%mod;
62         a3[i]=(ll)a3[i]*b3[i]%mod;
63     }
64     fwt(a1,bit,0,-1);
65     fwt(a2,bit,1,-1);
66     fwt(a3,bit,2,-1);
67     for(int i=0;i<bit;i++)printf("%d ",a1[i]);
68     printf("\n");
69     for(int i=0;i<bit;i++)printf("%d ",a2[i]);
70     printf("\n");
71     for(int i=0;i<bit;i++)printf("%d ",a3[i]);
72     return 0;
73 }
posted @ 2018-07-24 12:37  DCDCBigBig  阅读(291)  评论(0编辑  收藏  举报