【数学】快速Walsh变换

快速Walsh变换

给两个长度为 \(n\) 的序列 \(a,b\) ,满足 \(n=2^k\) ,序列标号为 \(a_0,a_1,\cdots,a_{n-1}\)

求AND卷积:

序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\&k)=i}a_j\cdot b_k\)

求OR卷积:

序列 \(c\) ,满足 \(c_i=\sum\limits_{(j|k)=i}a_j\cdot b_k\)

求XOR卷积:

序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\oplus k)=i}a_j\cdot b_k\)

复杂度 \(O(n\log n)\)

模数可以任意选,但要注意 XOR 逆变换中的“除以 2”如何实现:只要模数是奇数(例如奇质数),就可以改为乘以 2 的逆元 \((MOD+1)/2\)。

/**
 * Fast Walsh transforms and bitwise (AND / OR / XOR) convolutions of integer
 * sequences modulo MOD.
 *
 * All sequence values are expected to already lie in [0, MOD); the helpers
 * below only perform a single conditional correction and do not normalise
 * arbitrary inputs.
 */
namespace FWT {

    const int MOD = 998244353;
    // Multiplicative inverse of 2 modulo MOD: 2 * (MOD + 1) / 2 == MOD + 1 == 1 (mod MOD).
    const int INV2 = (MOD + 1) >> 1;

    /** Returns (x + y) mod MOD for x, y in [0, MOD). */
    inline int add(const int &x, const int &y) {
        int r = x + y;
        if(r >= MOD)
            r -= MOD;
        return r;
    }

    /** Returns (x - y) mod MOD for x, y in [0, MOD). */
    inline int sub(const int &x, const int &y) {
        int r = x - y;
        if(r < 0)
            r += MOD;
        return r;
    }

    /** Returns (x * y) mod MOD for x, y in [0, MOD); the product is formed in 64 bits. */
    inline int mul(const int &x, const int &y) {
        long long r = static_cast<long long>(x) * y;
        if(r >= MOD)
            r %= MOD;
        return static_cast<int>(r);
    }

    /**
     * In-place transform of a[0..n-1]; n is expected to be a power of two.
     *
     * op selects the transform:
     *     +1 AND      -1 inverse AND
     *     +2 OR       -2 inverse OR
     *     +3 XOR      -3 inverse XOR
     *
     * Any other op terminates the program with exit(-1) as soon as the first
     * butterfly is processed (so it goes unnoticed when n <= 1).
     */
    void FWT(int *a, int n, int op) {
        for(int l = 1; l < n; l <<= 1) {
            for(int i = 0; i < n; i += (l << 1)) {
                for(int j = 0; j < l; ++j) {
                    // x: element whose index has bit l clear, y: same index with bit l set.
                    int x = a[i + j], y = a[i + j + l];
                    switch(op) {
                    case +1:
                        // AND: accumulate into the lower half (sum over supersets).
                        a[i + j] = add(x, y);
                        break;
                    case +2:
                        // OR: accumulate into the upper half (sum over subsets).
                        a[i + j + l] = add(x, y);
                        break;
                    case +3:
                        // XOR: butterfly (x + y, x - y).
                        a[i + j] = add(x, y);
                        a[i + j + l] = sub(x, y);
                        break;
                    case -1:
                        a[i + j] = sub(x, y);
                        break;
                    case -2:
                        a[i + j + l] = sub(y, x);
                        break;
                    case -3:
                        // Inverse XOR: same butterfly, each output halved via INV2.
                        a[i + j] = mul(add(x, y), INV2);
                        a[i + j + l] = mul(sub(x, y), INV2);
                        break;
                    default:
                        exit(-1);
                    }
                }
            }
        }
    }

    /**
     * Bitwise convolution of A and B (both of length n, n a power of two);
     * the result is written to A.
     *
     * op = +1 AND, +2 OR, +3 XOR.
     *
     * B is left in its transformed state. A and B may be the same buffer
     * (self-convolution): the buffer is then transformed only once, instead
     * of twice, which would corrupt the result.
     */
    void Convolution(int *A, int *B, int n, int op) {
        assert(n > 0 && __builtin_popcount(n) == 1);
        FWT(A, n, op);
        if(B != A)
            FWT(B, n, op);
        for(int i = 0; i < n; ++i)
            A[i] = mul(A[i], B[i]);
        FWT(A, n, -op);
    }

}  // namespace FWT

子集卷积

求子集卷积:

序列 \(c\) ,满足 \(c_i=\sum\limits_{(j\& k)=0,(j| k)=i}a_j\cdot b_k\)

复杂度:\(O(n\log^2 n)\)

/** Number of 1 bits in the binary representation of x. */
inline int cnt1(const int &x) {
    const unsigned int bits = static_cast<unsigned int>(x);
    return __builtin_popcount(bits);
}

// Modulus used by the arithmetic helpers of the subset convolution.
const int MOD = 1000000009;
// Multiplicative inverse of 2 modulo MOD (MOD is odd, so (MOD + 1) / 2 is exact).
const int INV2 = (MOD + 1) / 2;

/** Returns (x + y) mod MOD, assuming both operands already lie in [0, MOD). */
inline int add(const int &x, const int &y) {
    const int sum = x + y;
    // A single conditional subtraction is enough for reduced operands.
    return sum >= MOD ? sum - MOD : sum;
}

/** Returns (x - y) mod MOD, assuming both operands already lie in [0, MOD). */
inline int sub(const int &x, const int &y) {
    const int diff = x - y;
    // A negative difference is lifted back into range by adding MOD once.
    return diff < 0 ? diff + MOD : diff;
}

/** Returns (x * y) mod MOD for operands in [0, MOD); the product is formed in 64 bits. */
inline int mul(const int &x, const int &y) {
    const long long product = static_cast<long long>(x) * y;
    // Only products that reach MOD are reduced; smaller values are returned unchanged.
    return static_cast<int>(product >= MOD ? product % MOD : product);
}

/**
 * In-place OR transform of a[0..n-1]: within every block of size 2 * half,
 * each element of the upper half is increased (mod MOD) by the matching
 * element of the lower half. n is expected to be a power of two.
 */
void FWT(int *a, int n) {
    for(int half = 1; half < n; half <<= 1) {
        const int block = half << 1;
        for(int start = 0; start < n; start += block) {
            const int *lower = a + start;
            int *upper = a + start + half;
            for(int k = 0; k < half; ++k)
                upper[k] = add(lower[k], upper[k]);
        }
    }
}

/**
 * In-place inverse of the OR transform on a[0..n-1]: within every block of
 * size 2 * half, the matching lower-half element is subtracted (mod MOD)
 * from each upper-half element. n is expected to be a power of two.
 */
void IFWT(int *a, int n) {
    for(int half = 1; half < n; half <<= 1) {
        const int block = half << 1;
        for(int start = 0; start < n; start += block) {
            const int *lower = a + start;
            int *upper = a + start + half;
            for(int k = 0; k < half; ++k)
                upper[k] = sub(upper[k], lower[k]);
        }
    }
}

// Problem size: ln is the number of bits, n the sequence length.
int ln;
int n;
// Upper bound on ln supported by the static tables below.
const int MAXLOGN = 20;
// Tables with one row per popcount value (0..MAXLOGN) and one column per index.
int a[MAXLOGN + 1][1 << MAXLOGN];
int b[MAXLOGN + 1][1 << MAXLOGN];
int c[MAXLOGN + 1][1 << MAXLOGN];

/**
 * Reads ln and two sequences of length n = 2^ln from stdin, computes their
 * subset convolution modulo MOD, and prints the n results on one line.
 *
 * Each input value is stored in the table row given by the popcount of its
 * index. Row `rank` of c accumulates the pointwise products of transformed
 * rows a[k] and b[rank - k]; after the inverse transform, only the entries
 * whose index has popcount `rank` are printed.
 */
void Solve() {
    ms(a), ms(b), ms(c);
    scanf("%d", &ln);
    n = 1 << ln;
    for(int mask = 0; mask < n; ++mask)
        scanf("%d", &a[cnt1(mask)][mask]);
    for(int mask = 0; mask < n; ++mask)
        scanf("%d", &b[cnt1(mask)][mask]);
    for(int rank = 0; rank <= ln; ++rank) {
        // Rows below `rank` were transformed in earlier iterations.
        FWT(a[rank], n);
        FWT(b[rank], n);
        int *out = c[rank];
        for(int k = 0; k <= rank; ++k) {
            const int *lhs = a[k];
            const int *rhs = b[rank - k];
            for(int mask = 0; mask < n; ++mask)
                out[mask] = add(out[mask], mul(lhs[mask], rhs[mask]));
        }
        IFWT(out, n);
    }
    for(int mask = 0; mask < n; ++mask)
        printf("%d%c", c[cnt1(mask)][mask], " \n"[mask == n - 1]);
}

posted @ 2021-02-05 18:59  purinliang  阅读(268)  评论(0)  编辑  收藏  举报