【数学】快速Walsh变换
快速Walsh变换
给定两个长度为 \(n\) 的序列 \(a,b\),满足 \(n=2^m\),下标从 \(0\) 开始,即 \(a_0,a_1,\cdots,a_{n-1}\)。
求AND卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\&k)=i}a_j\cdot b_k\)
求OR卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j|k)=i}a_j\cdot b_k\)
求XOR卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\oplus k)=i}a_j\cdot b_k\)
复杂度:\(O(n\log n)\)
AND/OR 变换对任意模数都成立,但逆 XOR 变换需要"除以 2":当模数为奇数(例如质数)时,2 存在逆元 \((\text{MOD}+1)/2\),用乘以逆元代替除法即可;若模数为偶数,2 不可逆,需要另行处理。
namespace FWT {
// Modulus for all arithmetic in this namespace. It must be odd so that 2 is
// invertible, because the inverse XOR transform divides by 2.
const int MOD = 998244353;
// Inverse of 2 modulo MOD: 2 * INV2 = MOD + 1, which is 1 (mod MOD).
const int INV2 = (MOD + 1) >> 1;

// (x + y) mod MOD. Both operands must already lie in [0, MOD).
inline int add(const int &x, const int &y) {
    int r = x + y;
    if(r >= MOD)
        r -= MOD;
    return r;
}

// (x - y) mod MOD. Both operands must already lie in [0, MOD).
inline int sub(const int &x, const int &y) {
    int r = x - y;
    if(r < 0)
        r += MOD;
    return r;
}

// (x * y) mod MOD via a 64-bit intermediate. The type is spelled `long long`
// so the namespace does not depend on an external `ll` typedef.
inline int mul(const int &x, const int &y) {
    long long r = static_cast<long long>(x) * y;
    if(r >= MOD)
        r %= MOD;
    return static_cast<int>(r);
}

// Maps any int into the canonical residue range [0, MOD).
// (In C++11 and later, % truncates toward zero, so a negative x gives a negative remainder.)
inline int normalize(const int &x) {
    int r = x % MOD;
    return r < 0 ? r + MOD : r;
}

/* In-place transform of a[0..n-1]; n is expected to be a power of two and
 * every entry must lie in [0, MOD). op =
 *   +1 AND   a[S] <- sum of a[T] over supersets T of S
 *   -1 IAND  inverse of +1
 *   +2 OR    a[S] <- sum of a[T] over subsets T of S
 *   -2 IOR   inverse of +2
 *   +3 XOR   Walsh-Hadamard butterfly (x, y) -> (x + y, x - y)
 *   -3 IXOR  inverse of +3: the same butterfly, then multiply by INV2
 * Any other op terminates the program via exit(-1). The check sits inside the
 * loops, so it only fires when n > 1.
 */
void FWT(int *a, int n, int op) {
    for(int l = 1; l < n; l <<= 1) {
        for(int i = 0; i < n; i += (l << 1)) {
            for(int j = 0; j < l; ++j) {
                // a[i + j] has bit l clear; a[i + j + l] is the same index with bit l set.
                int x = a[i + j], y = a[i + j + l];
                switch(op) {
                case +1:
                    a[i + j] = add(x, y);
                    break;
                case +2:
                    a[i + j + l] = add(x, y);
                    break;
                case +3:
                    a[i + j] = add(x, y);
                    a[i + j + l] = sub(x, y);
                    break;
                case -1:
                    a[i + j] = sub(x, y);
                    break;
                case -2:
                    a[i + j + l] = sub(y, x);
                    break;
                case -3:
                    a[i + j] = mul(add(x, y), INV2);
                    a[i + j + l] = mul(sub(x, y), INV2);
                    break;
                default:
                    exit(-1);
                }
            }
        }
    }
}

/* Convolution of A and B (both of length n, a power of two), computed in place:
 * afterwards A[i] = sum of A[j] * B[k] over all (j op k) == i, reduced mod MOD.
 *   op = +1 AND, +2 OR, +3 XOR.
 * Inputs may be arbitrary ints. They are first reduced into [0, MOD), which
 * add/sub require and which avoids signed overflow on raw input values.
 * Side effect: B is overwritten with the forward transform of its reduced values.
 */
void Convolution(int *A, int *B, int n, int op) {
    assert(n > 0 && (n & (n - 1)) == 0);
    for(int i = 0; i < n; ++i) {
        A[i] = normalize(A[i]);
        B[i] = normalize(B[i]);
    }
    FWT(A, n, op), FWT(B, n, op);
    for(int i = 0; i < n; ++i)
        A[i] = mul(A[i], B[i]);
    FWT(A, n, -op);
}
}  // namespace FWT
子集卷积
求子集卷积:
序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\& k)=0,(j| k)=i}a_j\cdot b_k\)
复杂度:\(O(n\log^2 n)\)
// Rank of the subset encoded by the bitmask x, i.e. the number of 1 bits in x.
inline int cnt1(const int &x) {
    const unsigned mask = static_cast<unsigned>(x);
    return __builtin_popcount(mask);
}
// Modulus for the subset-convolution code (1e9 + 9, a prime).
const int MOD = 1e9 + 9;
// Inverse of 2 modulo MOD: 2 * INV2 = MOD + 1, which is 1 (mod MOD).
// Not used by the subset-convolution routines in this snippet.
const int INV2 = (MOD + 1) >> 1;

// (x + y) mod MOD. Both operands must already lie in [0, MOD);
// then x + y < 2 * MOD, which still fits in a 32-bit int.
inline int add(const int &x, const int &y) {
    int r = x + y;
    if(r >= MOD)
        r -= MOD;
    return r;
}

// (x - y) mod MOD. Both operands must already lie in [0, MOD).
inline int sub(const int &x, const int &y) {
    int r = x - y;
    if(r < 0)
        r += MOD;
    return r;
}

// (x * y) mod MOD via a 64-bit intermediate. The type is spelled `long long`
// so the snippet does not depend on an external `ll` typedef.
inline int mul(const int &x, const int &y) {
    long long r = static_cast<long long>(x) * y;
    if(r >= MOD)
        r %= MOD;
    return static_cast<int>(r);
}
// In-place OR transform (subset-sum / zeta transform) of a[0..n-1], where
// n = 2^k. Afterwards a[S] holds the sum of the original a[T] over all subsets
// T of S, mod MOD. Entries are expected to lie in [0, MOD).
void FWT(int *a, int n) {
    for(int half = 1; half < n; half *= 2) {
        for(int base = 0; base < n; base += 2 * half) {
            int *lo = a + base;    // indices with the current bit clear
            int *hi = lo + half;   // the same indices with that bit set
            for(int t = 0; t < half; ++t)
                hi[t] = add(lo[t], hi[t]);
        }
    }
}
// Inverse of the OR transform (Moebius transform), in place on a[0..n-1] with
// n = 2^k: each entry with the current bit set has its bit-clear partner
// subtracted, mod MOD. Entries are expected to lie in [0, MOD).
void IFWT(int *a, int n) {
    for(int half = 1; half < n; half *= 2) {
        const int span = 2 * half;
        for(int base = 0; base < n; base += span) {
            int *lo = a + base;
            int *hi = lo + half;
            for(int t = 0; t < half; ++t)
                hi[t] = sub(hi[t], lo[t]);
        }
    }
}
// ln: log2 of the sequence length read by Solve(); n = 1 << ln.
int ln = 0;
int n = 0;
// Largest ln the static tables below can hold.
const int MAXLOGN = 20;
// Rank-split tables: Solve() stores input entry i in row cnt1(i).
// NOTE(review): each table is (MAXLOGN + 1) * 2^MAXLOGN ints (~88 MB with a
// 4-byte int), so ~264 MB in total; lower MAXLOGN if the memory limit is tight.
int a[MAXLOGN + 1][1 << MAXLOGN];
int b[MAXLOGN + 1][1 << MAXLOGN];
int c[MAXLOGN + 1][1 << MAXLOGN];
void Solve() {
ms(a), ms(b), ms(c);
scanf("%d", &ln), n = 1 << ln;
for(int i = 0; i < n; ++i)
scanf("%d", &a[cnt1(i)][i]);
for(int i = 0; i < n; ++i)
scanf("%d", &b[cnt1(i)][i]);
for(int x = 0; x <= ln; ++x) {
FWT(a[x], n), FWT(b[x], n);
for(int y = 0; y <= x; ++y)
for(int i = 0; i < n; ++i)
c[x][i] = add(c[x][i], mul(a[y][i], b[x - y][i]));
IFWT(c[x], n);
}
for(int i = 0; i < n; ++i)
printf("%d%c", c[cnt1(i)][i], " \n"[i == n - 1]);
}