快速沃尔什变换学习笔记
说好了在noip之前不学多项式算法……结果就真香了
快速沃尔什变换
给定长度为 \(2^n\) 两个序列 \(A,B\),设
分别当 \(\bigoplus\) 是 \(or,and,xor\) 时求出 \(C\)
\(n\leq 17\)
据说与 \(FFT\) 的核心思想相同,都是对数组的变换。
对于原数组A,B,在某种运算规则下,它们的结果很难求。但若是对该数组进行变换得到新数组 \(A'\) 和 \(B'\),而他们在该运算规则下的计算很好求,这时就能得到 \(C'\)。最后再对该序列进行逆变换即可获得答案 \(C\)。
记该变换数组为 \(FWT\),逆变换数组为 \(IFWT\)
or 运算下
定义 \(FWT[A] = \sum_{i|j=i}A_j\)
根据定义可推得
\(FWT[C] =FWT[A] \times FWT[B]\)
证明:
拆开 \(C\) 的定义式后,与上式形式相同。
证毕。
有了这个性质后我们接下来还需要解决两个问题
1:已知 A 如何快速求 FWT[A]
2:已知 FWT[A] 如何逆向求 A
记 A_0 为 A 下标中最高位为 0 的部分,A_1 为 A 下标中最高位为 1 的部分。
记\((G,K)\)表示将这两个序列前后接起来。
记 \(A + B\) 为$$\left{ A_1+B_1 ,A_2+B_2,A_3+B_3\dots A_n+B_n\right}$$
记 \(A \cdot B\)为
有
当\(2\leq|A|\)时
当\(n=1\)时$$FWT[A]=A$$
有
当\(2\leq|A|\)时
当\(n=1\)时$$IFWT[A]=A$$
按子集规规模理解即可,和下边的xor的数学归纳法证明相似
and 运算下
与运算同理
\(FWT[A]=\begin{cases}(FWT[A_0]+FWT[A_1],FWT[A_1])&2\leq n\\ A& n=1\end{cases}\)
\(IFWT[A]=\begin{cases}(IFWT[A_0]-IFWT[A_1],IFWT[A_1])&2\leq n\\ A& n=1\end{cases}\)
xor 运算下
突然就难了一个级别
定义FWT[A]如下定义
\(FWT[A]=\begin{cases}(FWT[A_0]+FWT[A_1],FWT[A_0]-FWT[A_1])&2\leq n\\ A& n=1\end{cases}\)
性质1:
\(FWT(A+B)=FWT(A)+FWT(B)\)
根据FWT[A]中每一维都是A中元素的线性组合可知
性质2:
\(FWT(A\bigoplus B)=FWT(A) \cdot FWT(B)\)
证明:
应用数学归纳法
\(n=1\)显然成立。
\(FWT(A⊕ B)=FWT((A⊕ B)_0,(A⊕ B)_1)\)
\(=FWT(A0⊕B0+A1⊕B1,A0⊕B1+A1⊕B0)\)
\(=FWT(A0⊕B0+A1⊕B1+A0⊕B1+A1⊕B0,A0⊕B0+A1⊕B1-A0⊕B1-A1⊕B0)\)
\(=(FWT(A0⊕B0)+FWT(A1⊕B1)+FWT(A0⊕B1)+FWT(A1⊕B0),FWT(A0⊕B0)+FWT(A1⊕B1)-FWT(A0⊕B1)-FWT(A1⊕B0))\)
\(=(FWT(A0)\cdot FWT(B0)+FWT(A1)\cdot FWT(B1)+FWT(A0)\cdot FWT(B1)+FWT(A1)\cdot FWT(B0),FWT(A0)\cdot FWT(B0)+FWT(A1)\cdot FWT(B1)-FWT(A0)\cdot FWT(B1)-FWT(A1)\cdot FWT(B0))\)
\(=(FWT(A0+A1)\cdot FWT(B0+B1),FWT(A0-A1)\cdot FWT(B0-B1))\)
(将这个式子做点乘得到上面那步
\(=(FWT(A0+A1)+FWT(A0-A1))\cdot (FWT(B0-B1)+FWT(B0-B1))\)
\(=FWT(A) \cdot FWT(B)\)
将\(\bigoplus\)拆解成点乘,就相当于数学归纳调用子问题
(证了这么多终于得到正向变换le
考虑逆向变换
\(IFWT[A]=\begin{cases}(\frac{1}{2}\times (IFWT[A_0]+IFWT[A_1]),\frac{1}{2}\times (IFWT[A_0]-IFWT[A_1]))&2\leq n\\ A& n=1\end{cases}\)
证明:
\(IFWT(FWT(A))=IFWT((FWT(A0+A1),FWT(A0-A1))\)
\(=(IFWT(FWT(A0)),IFWT(FWT(A1)))\)
\(=(A0,A1)\)
\(=A\)
证毕
粘一下板子
P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define INF 1ll<<30
#define Int unsigned long long
template<typename _T>
inline void read(_T &x)
{
x=0;char s=getchar();int f=1;
while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
x*=f;
}
#define lowbit(x) (x&(-x))
#define gb(x) ((x-1)/T + 1)
#define gl(x) ((x-1)*T + 1)
#define pb push_back
#define mod 998244353
const int np = (1<<17) + 5;
int A[np],B[np];
int A_[np],B_[np];
int c[np];
int n_,n;
inline int power(int a,int b)
{
int res = 1;
while(b)
{
if(b & 1) res = res * a % mod;
a = a * a;
a %= mod;
b>>=1;
}
return res;
}
inline void mul(){for(int i=0;i<n;i++)c[i] = A_[i] * B_[i] % mod;}
inline void FWTor(int *g,int opt)
{
for(int i=1;i < n;i <<= 1)
for(int o = 2 * i , j =0 ;j < n;j += o)
for(int k=0; k < i; k++)
{
g[j + k + i] += g[j + k] * opt;
g[j + k + i] = (g[j + k + i] + mod)%mod ;
}
}
inline void FWTand(int *g,int opt)
{
for(int i=1;i<n;i<<=1)
for(int o = 2 * i , j =0 ;j<n;j += o)
for(int k = 0;k<i;k++)
{
g[j + k] += g[j + k + i] * opt;
g[j + k] = (g[j +k] + mod) % mod;
}
}
inline void FWTxor(int *g,int opt)
{
for(int i=1;i<n;i<<=1)
for(int o=2 * i , j = 0;j<n;j += o)
for(int k=0;k<i;k++)
{
int x = g[j + k],y = g[j + k + i];
g[j + k] = (x + y) % mod;
g[j + k + i] = (x - y + mod) %mod;
if(opt != 1)
{
(g[j + k] *= opt)%=mod;
(g[j + k + i] *= opt)%=mod;
}
}
}
inline void Init()
{
for(int i=0;i<n;i++) A_[i] = A[i] , B_[i] = B[i];
}
signed main()
{
read(n_);
n = 1<<n_;
for(int i=0;i<n;i++) read(A[i]);
for(int i=0;i<n;i++) read(B[i]);
int inv = power(2,mod-2);
for(int i=0;i<n;i++) A_[i] = A[i] , B_[i] = B[i];
FWTor(A_,1);
FWTor(B_,1);
mul();
FWTor(c,-1);
for(int i=0;i<n;i++) printf("%lld ",c[i]);
printf("\n");
for(int i=0;i<n;i++) A_[i] = A[i] ,B_[i] = B[i];
FWTand(A_,1);
FWTand(B_,1);
mul();
FWTand(c,-1);
for(int i=0;i<n;i++) printf("%lld ",c[i]);
printf("\n");
Init();
FWTxor(A_,1);
FWTxor(B_,1);
mul();
FWTxor(c,inv);
for(int i=0;i<n;i++) printf("%lld ",c[i]);
// FWTxor();
}
CF449D Jzzhu and Numbers
写了一个 \(O(n \times 3^n)\) 暴力 dp
显然是过不了的,接下来有两种解决方案:
1:降维容斥
2:FWT科技解决问题
记 \(f_i\)为最后与出来的结果至少是\(i\),\(g_i\)为最后与出来的结果恰好是\(i\)。
那么有
则
将 \(g\) 继续展开有
现在我们考虑如何求 \(f\)
有\(f_x = 2^s-1\),其中 \(s\) 为 \(i\&x=x\)的数的个数
这个东西好像可以上沃尔什变换,然后 \(FWT\) 即可
代码循环展开了一下(为了卡最优解
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define INF 1ll<<30
#define Int unsigned long long
template<typename _T>
inline void read(_T &x)
{
x=0;char s=getchar();int f=1;
while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
x*=f;
}
#define lowbit(x) (x&(-x))
#define gb(x) ((x-1)/T + 1)
#define gl(x) ((x-1)*T + 1)
#define pb push_back
#define Re register
#define MOD(x) (x = (x + mod)%mod)
const int mod = 1e9 + 7;
const int np = 3e6 + 5;
int dp[np],dp_[np];
int A_[np];
int G[np];
int bac[np] , bit[np];
int A[np],n_;
int a[np];
int power(int a,int b)
{
Re int res = 1;
while(b)
{
if(b & 1) res = a * res % mod;
a = a * a;
a %= mod;
b>>=1;
}
return res;
}
inline void FWT(int *g)
{
for(int i=1;i<n_;i<<=1)
for(register int o = 2*i,j=0;j<n_;j+=o)
for(register int k=0;k<i;k++)
{
int &d = g[j + k];
d += g[j + k + i];
d >= mod?d -mod:0;
}
}
signed main()
{
int n,maxn = 0;
read(n);
n_ = 1;
for(int i=1;i<=n;i++) read(a[i]),bac[a[i]]++, maxn = max(maxn , a[i]);
while(n_ <= maxn) n_ <<= 1;
FWT(bac);
bit[0] = 0;
for(int i=1;i<n_;i++)
{
bit[i] = bit[i - lowbit(i)] + 1;
}
for(int i=0;i<n_;i++)
{
bit[i] = (bit[i]&1)?-1:1;
}
Re int i(-1),Ans1(0),Ans2(0),Ans3(0),Ans4(0),Ans5(0),Ans6(0),Ans7(0),Ans8(0);
Re int f1(0),f2(0),f3(0),f4(0),f5(0),f6(0),f7(0),f8(0),Ans(0);
for(;i + 8<n_;i+=8)
{
f1 = power(2,bac[i + 1]) - 1;
f2 = power(2,bac[i + 2]) - 1;
f3 = power(2,bac[i + 3]) - 1;
f4 = power(2,bac[i + 4]) - 1;
f5 = power(2,bac[i + 5]) - 1;
f6 = power(2,bac[i + 6]) - 1;
f7 = power(2,bac[i + 7]) - 1;
f8 = power(2,bac[i + 8]) - 1;
Ans1 += f1 * bit[i + 1];
MOD(Ans1);
Ans2 += f2 * bit[i + 2];
Ans3 += f3 * bit[i + 3];
Ans4 += f4 * bit[i + 4];
Ans5 += f5 * bit[i + 5];
Ans6 += f6 * bit[i + 6];
Ans7 += f7 * bit[i + 7];
Ans8 += f8 * bit[i + 8];
Ans += Ans1 + Ans2 + Ans3 + Ans4 + Ans5 + Ans6 + Ans7 + Ans8;
Ans1 =Ans2 = Ans3 = Ans4 = Ans5 =Ans6 = Ans7 = Ans8 = 0;
MOD(Ans);
}
i++;
for(Re int f(0);i<n_;i++)
{
f = power(2,bac[i])-1;
f *= bit[i];
Ans += f;
if(f < 0) Ans = (Ans + mod) %mod;
else Ans %= mod;
}
Ans += Ans1 + Ans2 + Ans3 + Ans4 + Ans5 + Ans6 + Ans7 + Ans8;
cout<<Ans;
}