P4717-[模板]快速莫比乌斯/沃尔什变换(FMT/FWT)

正题

题目链接:https://www.luogu.com.cn/problem/P4717


题目大意

给出两个长度为\(2^n\)的数列\(A,B\)

\[C_{n}=\sum_{i\ or\ j=n}A_iB_j \]

\[C_{n}=\sum_{i\ and\ j=n}A_iB_j \]

\[C_{n}=\sum_{i\ xor\ j=n}A_iB_j \]


解题思路

\(FFT\)一样的思路,我们需要将\(A,B\)转换为点集表达式相乘再转回去,这里就直接抛结论了。

对于\(or\)有$$FWT(A)=(FWT(A),FWT(A)+FWT(B))$$

\[IFWT(A)=(IFWT(A),IFWT(A)-IFWT(B)) \]

对于\(and\)有$$FWT(A)=(FWT(A)+FWT(B),FWT(B))$$

\[IFWT(A)=(IFWT(A)-IFWT(B),IFWT(B)) \]

对于\(xor\)有$$FWT(A)=(FWT(A)+FWT(B),FWT(A)-FWT(B))$$

\[IFWT(A)=(\frac{IFWT(A)+IFWT(B)}{2},\frac{IFWT(A)-IFWT(B)}{2}) \]

因为不用转成虚数所以代码好写很多,时间复杂度\(O(n\log n)\)


\(code\)

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll inv2=499122177,N=(1<<18),P=998244353;
ll lg,n,aa[N],bb[N],a[N],b[N],c[3][N];

void mul(ll *c,ll *a,ll *b){
    for(ll i=0;i<n;i++)
        c[i]=a[i]*b[i]%P;
    return;
}

void FWT_or(ll *f,ll op){
    for(ll p=2;p<=n;p<<=1){
        ll len=p>>1;
        for(ll k=0;k<n;k+=p)
            for(ll i=k;i<k+len;i++)
                (f[i+len]+=f[i]*op+P)%=P;
    }
    return;
}
void solve_or(ll *c){
    memcpy(a,aa,sizeof(a));
    memcpy(b,bb,sizeof(b));
    FWT_or(a,1);FWT_or(b,1);
    mul(c,a,b); FWT_or(c,P-1);
    return;
}

void FWT_and(ll *f,ll op){
    for(ll p=2;p<=n;p<<=1){
        ll len=p>>1;
        for(ll k=0;k<n;k+=p)
            for(ll i=k;i<k+len;i++)
                (f[i]+=f[i+len]*op+P)%=P;
    }
    return;
}
void solve_and(ll *c){
    memcpy(a,aa,sizeof(a));
    memcpy(b,bb,sizeof(b));
    FWT_and(a,1);FWT_and(b,1);
    mul(c,a,b); FWT_and(c,P-1);
    return;
}

void FWT_xor(ll *f,ll op){
    for(ll p=2;p<=n;p<<=1){
        ll len=p>>1;
        for(ll k=0;k<n;k+=p)
            for(ll i=k;i<k+len;i++){
                ll x=f[i],y=f[i+len];
                f[i]=(x+y)*op%P;
                f[i+len]=(x-y+P)*op%P;
            }
    }
    return;
}
void solve_xor(ll *c){
    memcpy(a,aa,sizeof(a));
    memcpy(b,bb,sizeof(b));
    FWT_xor(a,1);FWT_xor(b,1);
    mul(c,a,b); FWT_xor(c,inv2);
    return;
}

signed main()
{
    scanf("%lld",&lg);n=1<<lg;
    for(ll i=0;i<n;i++)scanf("%lld",&aa[i]);
    for(ll i=0;i<n;i++)scanf("%lld",&bb[i]);
    solve_or(c[0]);
    solve_and(c[1]);
    solve_xor(c[2]);
    for(ll i=0;i<n;i++)printf("%lld ",c[0][i]);putchar('\n');
    for(ll i=0;i<n;i++)printf("%lld ",c[1][i]);putchar('\n');
    for(ll i=0;i<n;i++)printf("%lld ",c[2][i]);putchar('\n');
    return 0;
}

------------恢复内容开始------------

------------恢复内容结束------------

posted @ 2021-01-05 20:13  QuantAsk  阅读(118)  评论(0编辑  收藏  举报