AT4996 [AGC034F] RNG and XOR

题目传送门

分析:
(太棒了学到虚脱)
先考虑暴力DP:

\[f_i=1+\sum_{j=0}^{2^n-1}p_jf_{i\oplus j} \]

感觉可以用FWT优化诶。。。
\(F\)\(f\)的集合幂级数,\(P\)\(p\)的集合幂级数,\(I\)为每一位都是1的集合幂级数
把上面的DP表达一下:

\[F*P+I=F+c \]

这里的乘法是异或卷积
这里\(c\)是一个常数,因为\(F(0)=0\)需要在常数位进行数值的调整
\(S(F)\)表示集合幂级数\(F\)每一位的和,发现:

\[S(F)*S(P)+S(I)=S(F)+c \]

带回最原始的DP式,对每个\(f_i\)求和,发现没有问题
由于\(S(P)=1,S(I)=2^n\),化简上式,发现\(c=2^n\)
对最初的式子变换一下:
\(F*P+I=F+c\)
\(F*(P-1)=c-I\)
发现是一个卷积形式,套FWT函数

\[FWT(F)_iFWT(P-1)_i=FWT(c-I)_i \]

\(i=0\)时,\(FWT(P-1)_i=(\sum_{j=0}^{2^n-1}p_i)-1=0\)
其余情况可以满足\(FWT(P-1)_i\)不为0,于是做除法:

\[FWT(F)_i=\frac{FWT(c-I)_i}{FWT(P-1)_i} \]

怎么求\(FWT(F)_0\)
把逆变换的式子拿出来:

\[f_i=\frac{1}{2^n}\sum_{j=0}^{2^n-1}(-1)^{popcount(i\&j)}FWT(F)_j \]

\(popcount(x)\)\(x\)二进制下1的个数
我们令\(i=0\),又因为\(f_i=0\),那么:

\[\sum_{j=0}^{2^n-1}FWT(F)_j=0 \]

\[FWT(F)_0=-\sum_{j=1}^{2^n-1}FWT(F)_j \]

逆变换回去,解出\(F\)即可
复杂度\(O(2^nn)\)

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<iostream>
#include<map>
#include<string>

#define maxn 3000005
#define MOD 998244353
#define inv2 499122177

using namespace std;

inline int getint()
{
	int num=0,flag=1;char c;
	while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
	while(c>='0'&&c<='9')num=num*10+c-48,c=getchar();
	return num*flag;
}

int n,N;
int E[maxn],P[maxn],I[maxn];

inline int ksm(int num,int k)
{
	int ret=1;
	for(;k;k>>=1,num=1ll*num*num%MOD)if(k&1)ret=1ll*ret*num%MOD;
	return ret;
}
inline void FWT_xor(int *a,int opt,int N)
{
	for(int i=1;i<N;i<<=1)for(int j=0;j<N;j+=(i<<1))
		for(int k=0;k<i;k++)
		{
			int x=a[j+k],y=a[i+j+k];
			a[j+k]=(x+y)%MOD,a[i+j+k]=(x-y+MOD)%MOD;
			if(!~opt)a[j+k]=1ll*a[j+k]*inv2%MOD,a[i+j+k]=1ll*a[i+j+k]*inv2%MOD;
		}
}

int main()
{
	n=getint(),N=1<<n;int s=0;
	for(int i=0;i<N;i++)s+=P[i]=getint(),s%=MOD;
	s=ksm(s,MOD-2);
	for(int i=0;i<N;i++)P[i]=1ll*P[i]*s%MOD;P[0]--;
	I[0]=N-1,s=0;
	for(int i=1;i<N;i++)I[i]=MOD-1;
	FWT_xor(P,1,N),FWT_xor(I,1,N);
	for(int i=1;i<N;i++)E[i]=1ll*I[i]*ksm(P[i],MOD-2)%MOD,s=(s+E[i])%MOD;
	E[0]=MOD-s;
	FWT_xor(E,-1,N);
	for(int i=0;i<N;i++)printf("%d\n",E[i]);
}

posted @ 2020-07-02 15:21  Izayoi_Doyo  阅读(169)  评论(0编辑  收藏  举报