HDU 6057 Kanade's convolution(FWT)

 

【题目链接】 http://acm.hdu.edu.cn/showproblem.php?pid=6057

 

【题目大意】

  有 C[k]=∑_(i&j=k)A[i^j]*B[i|j]
  求 Ans=∑ C[i]*1526^i%998244353

 

【题解】

  将C[k]代入Ans的计算式得到 Ans=∑ A[i^j]*B[i|j]*1526^(i&j)%MOD
  我们发现(i^j)&(i&j)=0且(i^j)^(i&j)=i|j,
  因此bit[i^j]+bit[i&j]=bit[i|j],并有(i^j)|(i&j)=i|j
  设x=i^j, y=i&j, z=i|j 我们发现x&z=x,
  所以每对乘法乘上2^bit[x]的参数即可。
  我们计算1526^x和A[y]*2^bit[y]的or卷积,然后按位和B数组相乘。
  考虑bit[x]+bit[y]=bit[z]的卷积限制要求,我们将x和y按照bit进行分维,
  对于维度做和为bit[x]+bit[y]=bit[z]的子集FWT。

 

【代码】

#include <cstdio>
#include <algorithm>
#include <cstring> 
using namespace std;
typedef long long LL;
const int mod=998244353;
LL pow(LL a,LL b,LL p){LL t=1;for(a%=p;b;b>>=1LL,a=a*a%p)if(b&1LL)t=t*a%p;return t;}
void FWT(int*a,int n){
    for(int d=1;d<n;d<<=1)for(int m=d<<1,i=0;i<n;i+=m)for(int j=0;j<d;j++){
        int x=a[i+j],y=a[i+j+d];
        a[i+j+d]=(x+y)%mod;
    }
}
void UFWT(int*a,int n){
    for(int d=1;d<n;d<<=1)for(int m=d<<1,i=0;i<n;i+=m)for(int j=0;j<d;j++){
        int x=a[i+j],y=a[i+j+d];
        a[i+j+d]=(y-x+mod)%mod;
    }
}
const int N=1<<20;
int n; 
int A[21][N],B[21][N],C[21][N],bit[N],a[N],b[N],c[N];
int main(){
    while(~scanf("%d",&n)){
        int len=1<<n;
        for(int i=0;i<len;i++)scanf("%d",&a[i]);
        for(int i=0;i<len;i++)scanf("%d",&b[i]);
        for(int i=0;i<len;i++)bit[i]=bit[i>>1]+(i&1);
        memset(A,0,sizeof(A));
        memset(B,0,sizeof(B));
        memset(C,0,sizeof(C)); 
        LL t=1;
        for(int i=0;i<len;i++){
            A[bit[i]][i]=1LL*a[i]*(1<<bit[i])%mod;
            B[bit[i]][i]=t;
            t=t*1526%mod;
        }
        for(int i=0;i<=n;i++){
            FWT(A[i],len);
            FWT(B[i],len);
        }
        for(int k=0;k<=n;k++){
            for(int j=0;j+k<=n;j++){
                for(int i=0;i<len;i++)C[j+k][i]=(C[j+k][i]+1LL*A[j][i]*B[k][i]%mod)%mod;
            }
        }
        for(int i=0;i<=n;i++)UFWT(C[i],len);
        for(int i=0;i<len;i++)c[i]=C[bit[i]][i];
        LL ans=0;
        for(int i=0;i<len;i++)ans=(ans+(1LL*c[i]*b[i])%mod)%mod;
        printf("%d\n",ans);
    }return 0;
}
posted @ 2017-08-04 16:41  forever97  阅读(265)  评论(3编辑  收藏  举报