算法学习————FWT
解决问题
其中\(\bigoplus\)为or,and,xor,已知A和B,求解C
和FFT还有NTT的思想都是一样的,考虑在FFT的时候,我们是从系数法转化成点值法
对A和B本身FFT一次,想乘后得到C,然后用逆运算再把点值法转化成系数法,下面FWT也是一样的流程
Or
直接设\(FWTA(i) = \sum\limits_{j|i = i} A_j\)
假设可以通过\(FWTA(i)\)和\(FWTB(i)\)想乘得到\(FWTC(i)\),有\(FWTA(i)\times FWTB(i) = FWTC(i)\)
如何证明这个式子?
左边:
这里解释一下为什么j|i = i且k|i = i可以合并为j|k|i = i,我们可以知道或取得是两个数1的并集
这就说明j的1是i的1的子集,同理k,那么j和k的子集也是i的子集
右边:
易证两边等价
And
对于与也是一样的,设\(FWTA(i) = \sum\limits_{j \And i = i} A_i\)
左边:
右边:
Xor
比较麻烦的就是异或操作了
设\(FWTA(i) = \sum\limits_{j} A_j\times (-1)^{count(i\And j)}\)
那么左边:
右边:
观察两边的形式,那么我们现在只需要证明\((-1)^{count(i\And (x\bigoplus y))} = (-1)^{count(i\And x)+count(i\And y)}\)
我们可以知道左边异或操作把x和y共有的1给消掉了,而右边对于x和y共有的1,\((-1)^2 = 1\)乘1没有贡献,剩下的都是x和y不共有的1产生的贡献,数量相同
式子到这里就推完了,那么现在的问题就是如果快速的求出\(FWTA(i) FWTB(i)\),以及求逆的过程
我们把序列分成几段,刚开始长度为2:
000 001 | 010 011 | 100 101 | 110 111
先看对于或操作,我们可以知道一个数,包含的1是他的1的子集的数会对他产生贡献,就看块内前一个数会对后一个数产生贡献,贡献累加
对于与操作,则反过来,后一个数对前一个数有贡献,对于异或操作,好像可以推式子得来
之后我们可以每次把块长$\times$2,贡献一直累加,就能求出最后的答案
代码:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#define int long long
#define O(x) cout<<#x<<" "<<x<<endl;
#define o(x) cout<<#x<<" "<<x<<" ";
using namespace std;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
return x*a;
}
const int maxn = 2e6+10,mod = 998244353;
int n,inv;
int A[maxn],B[maxn];
int qpow(int a,int k){
int res = 1;a = a % mod;
while (k){
if (k&1) res = res*a % mod;
a = a*a % mod;
k >>= 1;
}
return res % mod;
}
int a[maxn],b[maxn];
void FWTOr(int a[],int op){
for (int l = 1;l < n;l <<= 1){
for (int i = 0;i < n;i += (l << 1)){
for (int j = 0;j < l;j++){
(a[i+j+l] += op*a[i+j]) %= mod;
}
}
}
}
void FWTAnd(int a[],int op){
for (int l = 1;l <= n;l <<= 1){
for (int i = 0;i < n;i += (l << 1)){
for (int j = 0;j < l;j++){
(a[i+j] += op*a[i+j+l]) %= mod;
}
}
}
}
void FWTXor(int a[],int op){
for (int l = 1;l < n;l <<= 1){
for (int i = 0;i < n;i += (l << 1)){
for (int j = 0;j < l;j++){
int x = a[i+j],y = a[i+j+l];
a[i+j] = (x+y) % mod,a[i+j+l] = (x+mod-y) % mod;
if (op == -1) a[i+j] = a[i+j]*inv % mod,a[i+j+l] = a[i+j+l]*inv % mod;
}
}
}
}
signed main(){
n = read(),n = (1 << n);
inv = qpow(2,mod-2);
for (int i = 0;i < n;i++) a[i] = read();
for (int i = 0;i < n;i++) b[i] = read();
for (int i = 0;i < n;i++) A[i] = a[i],B[i] = b[i];
FWTOr(A,1),FWTOr(B,1);
for (int i = 0;i < n;i++) A[i] = A[i]*B[i] % mod;
FWTOr(A,-1);
for (int i = 0;i < n;i++) cout<<((A[i] + mod) % mod)<<" ";
cout<<endl;
for (int i = 0;i < n;i++) A[i] = a[i],B[i] = b[i];
FWTAnd(A,1),FWTAnd(B,1);
for (int i = 0;i < n;i++) A[i] = A[i]*B[i] % mod;
FWTAnd(A,-1);
for (int i = 0;i < n;i++) cout<<((A[i] + mod) % mod)<<" ";
cout<<endl;
for (int i = 0;i < n;i++) A[i] = a[i],B[i] = b[i];
FWTXor(A,1),FWTXor(B,1);
for (int i = 0;i < n;i++) A[i] = A[i]*B[i] % mod;
FWTXor(A,-1);
for (int i = 0;i < n;i++) cout<<((A[i] + mod) % mod)<<" ";
cout<<endl;
return 0;
}