P4717-[模板]快速莫比乌斯/沃尔什变换(FMT/FWT)
正题
题目链接:https://www.luogu.com.cn/problem/P4717
题目大意
给出两个长度为\(2^n\)的数列\(A,B\)求
\[C_{n}=\sum_{i\ or\ j=n}A_iB_j
\]
\[C_{n}=\sum_{i\ and\ j=n}A_iB_j
\]
\[C_{n}=\sum_{i\ xor\ j=n}A_iB_j
\]
解题思路
和\(FFT\)一样的思路,我们需要将\(A,B\)转换为点集表达式相乘再转回去,这里就直接抛结论了。
对于\(or\)有$$FWT(A)=(FWT(A),FWT(A)+FWT(B))$$
\[IFWT(A)=(IFWT(A),IFWT(A)-IFWT(B))
\]
对于\(and\)有$$FWT(A)=(FWT(A)+FWT(B),FWT(B))$$
\[IFWT(A)=(IFWT(A)-IFWT(B),IFWT(B))
\]
对于\(xor\)有$$FWT(A)=(FWT(A)+FWT(B),FWT(A)-FWT(B))$$
\[IFWT(A)=(\frac{IFWT(A)+IFWT(B)}{2},\frac{IFWT(A)-IFWT(B)}{2})
\]
因为不用转成虚数所以代码好写很多,时间复杂度\(O(n\log n)\)
\(code\)
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll inv2=499122177,N=(1<<18),P=998244353;
ll lg,n,aa[N],bb[N],a[N],b[N],c[3][N];
void mul(ll *c,ll *a,ll *b){
for(ll i=0;i<n;i++)
c[i]=a[i]*b[i]%P;
return;
}
void FWT_or(ll *f,ll op){
for(ll p=2;p<=n;p<<=1){
ll len=p>>1;
for(ll k=0;k<n;k+=p)
for(ll i=k;i<k+len;i++)
(f[i+len]+=f[i]*op+P)%=P;
}
return;
}
void solve_or(ll *c){
memcpy(a,aa,sizeof(a));
memcpy(b,bb,sizeof(b));
FWT_or(a,1);FWT_or(b,1);
mul(c,a,b); FWT_or(c,P-1);
return;
}
void FWT_and(ll *f,ll op){
for(ll p=2;p<=n;p<<=1){
ll len=p>>1;
for(ll k=0;k<n;k+=p)
for(ll i=k;i<k+len;i++)
(f[i]+=f[i+len]*op+P)%=P;
}
return;
}
void solve_and(ll *c){
memcpy(a,aa,sizeof(a));
memcpy(b,bb,sizeof(b));
FWT_and(a,1);FWT_and(b,1);
mul(c,a,b); FWT_and(c,P-1);
return;
}
void FWT_xor(ll *f,ll op){
for(ll p=2;p<=n;p<<=1){
ll len=p>>1;
for(ll k=0;k<n;k+=p)
for(ll i=k;i<k+len;i++){
ll x=f[i],y=f[i+len];
f[i]=(x+y)*op%P;
f[i+len]=(x-y+P)*op%P;
}
}
return;
}
void solve_xor(ll *c){
memcpy(a,aa,sizeof(a));
memcpy(b,bb,sizeof(b));
FWT_xor(a,1);FWT_xor(b,1);
mul(c,a,b); FWT_xor(c,inv2);
return;
}
signed main()
{
scanf("%lld",&lg);n=1<<lg;
for(ll i=0;i<n;i++)scanf("%lld",&aa[i]);
for(ll i=0;i<n;i++)scanf("%lld",&bb[i]);
solve_or(c[0]);
solve_and(c[1]);
solve_xor(c[2]);
for(ll i=0;i<n;i++)printf("%lld ",c[0][i]);putchar('\n');
for(ll i=0;i<n;i++)printf("%lld ",c[1][i]);putchar('\n');
for(ll i=0;i<n;i++)printf("%lld ",c[2][i]);putchar('\n');
return 0;
}
------------恢复内容开始------------
------------恢复内容结束------------