FWT快速沃尔什变换

前言

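学多项式怎么能错过\(FWT\)呢,然而这真是个毒瘤的东西,蒟蒻就只会背公式了\(\%>\_<\%\)

记号说明：\(A_0\)、\(A_1\) 分别表示序列 \(A\) 中下标最高位为 0 和为 1 的前后两半，\(tf\) 表示正变换，\(utf\) 表示逆变换，\((X, Y)\) 表示把两半按顺序拼接。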

或卷积

\[\begin{aligned}\\ tf(A) = (tf(A_0), tf(A_1) + tf(A_0))\\ utf(A) = (utf(A_0), utf(A_1) - utf(A_0))\\ \end{aligned}\]

与卷积

\[\begin{aligned}\\ tf(A) = (tf(A_0) + tf(A_1), tf(A_1))\\ utf(A) = (utf(A_0) - utf(A_1), utf(A_1))\\ \end{aligned}\]

异或卷积

\[\begin{aligned}\\ tf(A) = (tf(A_0) + tf(A_1), tf(A_0) - tf(A_1))\\ utf(A) = (\frac{utf(A_0) + utf(A_1)}{2}, \frac{utf(A_0) - utf(A_1)}{2})\\ \end{aligned}\]

Code

习惯写递归版本（非递归版本本来也不会写）：

#include<bits/stdc++.h>
// Integer type used for all coefficients and indices in this file.
// NOTE: despite the name, LL is plain int here; code that needs 64-bit
// intermediates widens explicitly with 1ll* / long long.
using LL = int;
// Reads one decimal integer from stdin.
// Any non-digit characters before the first digit are skipped; if a '-' is
// among them the result is negated. Reading stops at the first non-digit
// after the number (that character is consumed).
// Returns 0 if end of input is reached before any digit; the previous version
// looped forever in that case.
inline LL Read(){
	LL x(0),f(1);
	int c=getchar(); // int, not char: must be able to hold EOF distinctly
	while(c<'0' || c>'9'){
		if(c==EOF) return 0;
		if(c=='-') f=-1;
		c=getchar();
	}
	while(c>='0' && c<='9'){
		x=(x<<3)+(x<<1)+c-'0'; // x*10 + digit
		c=getchar();
	}
	return x*f;
}
const LL mod=998244353;   // prime modulus for all arithmetic
const LL maxn=1<<18;      // capacity of the global coefficient arrays
const LL inv2=499122177;  // modular inverse of 2, i.e. (mod+1)/2

// Binary exponentiation: returns base^b modulo mod.
// b is expected to be non-negative. Not called by the code visible here.
inline LL Pow(LL base,LL b){
	long long acc=1,sq=base;
	for(;b;b>>=1){
		if(b&1) acc=acc*sq%mod;
		sq=sq*sq%mod;
	}
	return acc;
}
// OR convolution: c[k] = sum over i|j==k of a[i]*b[j], modulo mod.
// n is the length of a, b and c and is assumed to be a power of two (n >= 1).
// a and b are clobbered: the upper halves are overwritten at every level.
// Split by the top index bit, A=(A0,A1), B=(B0,B1):
//   C0 = A0 | B0,   C1 = (A0+A1) | (B0+B1) - C0.
// Recurrence T(n) = 2T(n/2) + O(n), i.e. O(n log n).
void Solve_or(LL n,LL *a,LL *b,LL *c){
	n>>=1;
	if(!n){
		// Normalize into [0,mod) so negative coefficients cannot leak out
		// and overflow the subtraction below.
		c[0]=(1ll*a[0]*b[0]%mod+mod)%mod;
		return;
	}
	for(LL i=0;i<n;++i){
		// Widen BEFORE adding; 1ll*(x+y) would add in int and may overflow.
		a[i+n]=(1ll*a[i+n]+a[i])%mod;
		b[i+n]=(1ll*b[i+n]+b[i])%mod;
	}
	Solve_or(n,a,b,c); Solve_or(n,a+n,b+n,c+n);
	// Both halves of c are in [0,mod), so the sum below stays under 2*mod.
	for(LL i=0;i<n;++i) c[i+n]=(c[i+n]-c[i]+mod)%mod;
}
// AND convolution: c[k] = sum over i&j==k of a[i]*b[j], modulo mod.
// n is the length of a, b and c and is assumed to be a power of two (n >= 1).
// a and b are clobbered: the lower halves are overwritten at every level.
// Split by the top index bit, A=(A0,A1), B=(B0,B1):
//   C1 = A1 & B1,   C0 = (A0+A1) & (B0+B1) - C1.
// Recurrence T(n) = 2T(n/2) + O(n), i.e. O(n log n).
void Solve_and(LL n,LL *a,LL *b,LL *c){
	n>>=1;
	if(!n){
		// Normalize into [0,mod) so negative coefficients cannot leak out
		// and overflow the subtraction below.
		c[0]=(1ll*a[0]*b[0]%mod+mod)%mod;
		return;
	}
	for(LL i=0;i<n;++i){
		// Widen BEFORE adding; 1ll*(x+y) would add in int and may overflow.
		a[i]=(1ll*a[i]+a[i+n])%mod;
		b[i]=(1ll*b[i]+b[i+n])%mod;
	}
	Solve_and(n,a,b,c); Solve_and(n,a+n,b+n,c+n);
	// Both halves of c are in [0,mod), so the sum below stays under 2*mod.
	for(LL i=0;i<n;++i) c[i]=(c[i]-c[i+n]+mod)%mod;
}
// XOR convolution: c[k] = sum over i^j==k of a[i]*b[j], modulo mod.
// n is the length of a, b and c and is assumed to be a power of two (n >= 1).
// a and b are clobbered: both halves are overwritten at every level.
// Split by the top index bit, A=(A0,A1), B=(B0,B1):
//   P = (A0+A1) ^ (B0+B1) = C0 + C1,   Q = (A0-A1) ^ (B0-B1) = C0 - C1,
//   so C0 = (P+Q)/2 and C1 = (P-Q)/2 (division by 2 via inv2).
// Recurrence T(n) = 2T(n/2) + O(n), i.e. O(n log n).
void Solve_xor(LL n,LL *a,LL *b,LL *c){
	n>>=1;
	if(!n){
		// Normalize into [0,mod) so the combine step below works on
		// non-negative values only.
		c[0]=(1ll*a[0]*b[0]%mod+mod)%mod;
		return;
	}
	for(LL i=0;i<n;++i){
		// 64-bit temporaries: sums/differences of two ints may exceed int.
		const long long a0=a[i],a1=a[i+n],b0=b[i],b1=b[i+n];
		a[i]=(a0+a1)%mod; a[i+n]=((a0-a1)%mod+mod)%mod;
		b[i]=(b0+b1)%mod; b[i+n]=((b0-b1)%mod+mod)%mod;
	}
	Solve_xor(n,a,b,c); Solve_xor(n,a+n,b+n,c+n);
	for(LL i=0;i<n;++i){
		const long long p=c[i],q=c[i+n];
		c[i]=(p+q)%mod*inv2%mod;
		c[i+n]=(p-q+mod)%mod*inv2%mod;
	}
}
LL n,N;
LL a[maxn],b[maxn],c[maxn],d[maxn],e[maxn],f[maxn],x[maxn],y[maxn],z[maxn];
int main(){
	n=Read();
	N=1<<n;
	for(LL i=0;i<N;++i) a[i]=c[i]=e[i]=Read();
	for(LL i=0;i<N;++i) b[i]=d[i]=f[i]=Read();
	Solve_or(N,a,b,x);
	Solve_and(N,c,d,y);
	Solve_xor(N,e,f,z);
	for(LL i=0;i<N;++i) printf("%d ",x[i]);printf("\n");
	for(LL i=0;i<N;++i) printf("%d ",y[i]);printf("\n");
	for(LL i=0;i<N;++i) printf("%d ",z[i]);printf("\n");
	return 0;
}
posted @ 2019-04-18 11:32  y2823774827y  阅读(188)  评论(0)  编辑  收藏  举报