[HDU 6057] Kanade's convolution

一、题目

点此看题

二、解法

你发现这个题是个 \(\tt FWT\) 的魔改版,回忆一下我们正常的 \(\tt FWT\) 只能做下面的式子:

\[C[k]=\sum_{i?j=k}A[i]\times B[j] \]

但是这个题离谱得给我们一次塞了三个位运算进去:

\[C[k]=\sum_{i\and j=k}A[i\oplus j]\times B[i\or j] \]

这时候不要被吓到了,我们要尽量套用 \(\tt FWT\) 的基本式子,令 \(x=i\oplus j,y=i\or j\) ,那么 \(i\and j=y-x\) ,那么尝试把全部都换成 \(x,y\) ,对于一对确定的 \((x,y)\) ,满足条件的 \(i,j\) 是有 \(2^{bit(x)}\) 对的(因为 \(x\)\(1\) 的位置上有两种组合),还要注意因为是乱枚举的 \((x,y)\) 所以 \(x\in y\) ,那么式子变成了:

\[C[k]=\sum_{x,y}[y-x=k][y\and x=x]B[y]\times A[x]\times 2^{bit(x)} \]

把减号也换成位运算,利用 \(x\in y\) 的关系:

\[C[k]=\sum_{x,y}[y\oplus x=k][y\and x=x]B[y]\times A[x]\times 2^{bit(x)} \]

\[C[k]=\sum_{x\oplus y=k}[y\and x=x]B[y]\times A[x]\times 2^{bit(x)} \]

能套 \(\tt FWT\) 的雏形已经出来了,现在就要解决这个讨厌的 \(y\and x=x\) ,利用特殊的异或运算,可以转化成数位的关系:

\[C[k]=\sum_{x\oplus y=k}[bit(y)-bit(x)=bit(k)]B[y]\times A[x]\times 2^{bit(x)} \]

这就变成完全能做的东西了,用类似子集卷积的方法就行了。

#include <cstdio>
#include <iostream>
using namespace std; 
const int M = 600005;
const int MOD = 998244353;
#define ll long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,inv,ans,a[20][M],b[20][M],c[20][M],pw[M],bit[M];
void fwt(int *a,int n,int op)
{
	for(int i=1;i<n;i<<=1)
		for(int p=i<<1,j=0;j<n;j+=p)
			for(int k=0;k<i;k++)
			{
				int x=a[j+k],y=a[i+j+k];
				a[j+k]=(x+y)%MOD;
				a[i+j+k]=(x-y+MOD)%MOD;
				if(op==-1)
				{
					a[j+k]=1ll*a[j+k]*inv%MOD;
					a[i+j+k]=1ll*a[i+j+k]*inv%MOD;
				}
			}
}
signed main()
{
	n=read();m=1<<n;
	inv=(MOD+1)/2;pw[0]=1;
	for(int i=1;i<m;i++)
	{
		bit[i]=bit[i>>1]+(i&1);
		pw[i]=pw[i-1]*2ll%MOD;
	}
	for(int i=0;i<m;i++)
		a[bit[i]][i]=1ll*read()*pw[bit[i]]%MOD;
	for(int i=0;i<m;i++)
		b[bit[i]][i]=read();
	for(int i=0;i<=n;i++)
		fwt(a[i],m,1),fwt(b[i],m,1);
	for(int i=0;i<=n;i++)
		for(int j=0;j<=i;j++)
			for(int k=0;k<m;k++)
				c[i-j][k]=(c[i-j][k]+1ll*b[i][k]*a[j][k])%MOD;
	for(int i=0;i<=n;i++)
		fwt(c[i],m,-1);
	ll x=1;
	for(int i=0;i<m;i++)
	{
		ans=(ans+x*c[bit[i]][i])%MOD;
		x=x*1526%MOD;
	}
	printf("%d\n",ans);
}
posted @ 2021-02-03 19:18  C202044zxy  阅读(42)  评论(0编辑  收藏  举报