题解 [UNR #2] 黎明前的巧克力

传送门

原题

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 3000010
#define ll long long
//#define int long long

// Input buffer of 1<<21 bytes. p1 points at the next unread byte, p2 one past
// the last byte that the most recent fread stored in buf.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): when the buffer is exhausted (p1==p2) it
// is refilled from stdin with a single fread; if the refill read nothing the
// expression yields EOF, otherwise it yields the next buffered byte and advances p1.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
/**
 * Reads the next decimal integer from the input via getchar().
 *
 * Every character before the first digit is skipped; each '-' seen while
 * skipping flips the sign of the result. Digits are then accumulated until the
 * first non-digit, which is consumed as well.
 *
 * @return the parsed value, or 0 if the end of input is reached before any
 *         digit (the skip loop used to spin forever in that case).
 */
inline int read() {
	int ans=0, f=1;
	// Keep the character in an int: a char cannot reliably represent EOF, and
	// passing a negative char to isdigit() is undefined behaviour, so digits
	// are recognised with explicit range checks instead.
	int c=getchar();
	while (c!=EOF && (c<'0' || c>'9')) {
		if (c=='-') f=-f;
		c=getchar();
	}
	while (c>='0' && c<='9') {
		ans=ans*10+(c-'0');
		c=getchar();
	}
	return ans*f;
}

// Number of input values; set by main.
int n;
// Bit width of the transforms: main picks the smallest m >= 1 with (1<<m) > maxn.
int m;
// Number of (index, weight) terms stored per input element.
int k=2;
// a[i][j] is the index of the j-th term of element i (1-based i);
// main stores the input value in a[i][0] and 0 in a[i][1].
int a[N][2];
// w[j] is the coefficient that belongs to index a[i][j].
int w[]={2, 1};
// Largest input value seen by main.
int maxn;
// Modulus used by qpow, fwt and both solvers.
const ll mod=998244353;
// Inverse of 2 modulo mod (2*inv2 == mod+1).
const ll inv2=(mod+1)>>1;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a base; it is not reduced before the first multiplication, so the
 *          result keeps the sign C++'s % gives it and may be negative.
 * @param b exponent, assumed to be non-negative.
 * @return a value congruent to a^b modulo mod (1 when b == 0).
 */
inline ll qpow(ll a, ll b) {
	ll result=1;
	while (b) {
		if (b&1) result=result*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return result;
}
/**
 * In-place Walsh-Hadamard (XOR) transform of a[0..len) modulo mod.
 *
 * Each butterfly replaces the pair (x, y) with (x+y, x-y) reduced with %, so
 * entries may be negative representatives.
 *
 * @param a   array to transform in place.
 * @param len number of entries; assumed to be a power of two.
 * @param op  -1 additionally multiplies both butterfly outputs by inv2 at every
 *            level (inverse transform); any other value leaves them unscaled.
 */
void fwt(ll* a, int len, int op) {
	const bool inverse=(op==-1);
	for (int half=1; half<len; half<<=1) {
		const int span=half<<1;
		for (int base=0; base<len; base+=span) {
			for (int lo=base; lo<base+half; ++lo) {
				const int hi=lo+half;
				const ll x=a[lo], y=a[hi];
				ll sum=(x+y)%mod;
				ll diff=(x-y)%mod;
				if (inverse) {
					sum=sum*inv2%mod;
					diff=diff*inv2%mod;
				}
				a[lo]=sum;
				a[hi]=diff;
			}
		}
	}
}

namespace force{
	ll ans[N], f[N];
	/**
	 * Brute-force reference solver: transforms every element separately.
	 *
	 * For each element i it builds the array holding w[j] at index a[i][j],
	 * applies the forward fwt and multiplies it point-wise into ans. The
	 * inverse fwt of the product is then taken and (ans[0]-1), normalised to
	 * [0, mod), is printed.
	 *
	 * Reads the globals n, m, k, a and w; costs one transform of length 1<<m
	 * per element.
	 */
	void solve() {
		const int size=1<<m;
		std::fill(ans, ans+size, 1LL);
		for (int i=1; i<=n; ++i) {
			// Only the first `size` entries are ever read or written, so it is
			// enough to clear that prefix instead of the whole N-entry array.
			std::fill(f, f+size, 0LL);
			// Accumulate rather than assign: when two terms share an index
			// (a[i][0]==a[i][1]==0 for an input value of 0) their weights must add up.
			for (int j=0; j<k; ++j) f[a[i][j]]+=w[j];
			fwt(f, size, 1);
			for (int j=0; j<size; ++j) ans[j]=ans[j]*f[j]%mod;
		}
		fwt(ans, size, -1);
		cout<<((ans[0]-1)%mod+mod)%mod<<'\n';
	}
}

namespace task{
	int b[N];
	ll f[N], cnt[N][4], ans[N];
	void solve() {
		for (int i=0; i<(1<<m); ++i) ans[i]=1;
		for (int t=0; t<(1<<k); ++t) {
			memset(f, 0, sizeof(f));
			memset(b, 0, sizeof(b));
			for (int i=1; i<=n; ++i) for (int j=0; j<k; ++j) if (t&(1<<j)) b[i]^=a[i][j];
			for (int i=1; i<=n; ++i) ++f[b[i]];
			fwt(f, 1<<m, 1);
			for (int q=0; q<(1<<m); ++q) cnt[q][t]=f[q];
		}
		for (int i=0; i<(1<<m); ++i) {
			// cout<<"i: "<<i<<endl;
			// cout<<"cnt: "; for (int j=0; j<(1<<k); ++j) cout<<cnt[i][j]<<' '; cout<<endl;
			fwt(cnt[i], 1<<k, -1);
			// cout<<"cnt: "; for (int j=0; j<(1<<k); ++j) cout<<cnt[i][j]<<' '; cout<<endl;
			for (int s=0; s<(1<<k); ++s) {
				ll sum=0;
				for (int j=0; j<k; ++j)
					if (s&(1<<j)) sum=(sum-w[j])%mod;
					else sum=(sum+w[j])%mod;
				ans[i]=ans[i]*qpow(sum, cnt[i][s]);
			}
		}
		fwt(ans, 1<<m, -1);
		// cout<<"ans: "; for (int i=0; i<(1<<m); ++i) cout<<ans[i]<<' '; cout<<endl;
		cout<<((ans[0]-1)%mod+mod)%mod<<endl;
	}
}

/**
 * Entry point: reads n and then n values, stores each as the term pair
 * (value, 0) in a, tracks the largest value in maxn, chooses the transform
 * width m and runs the fast solver.
 */
signed main()
{
	n=read();
	for (int i=1; i<=n; ++i) {
		const int value=read();
		a[i][0]=value;
		a[i][1]=0;
		if (maxn<value) maxn=value;
	}
	// Smallest m >= 1 such that every input value is below 1<<m.
	m=1;
	while ((1<<m)<=maxn) ++m;
	// force::solve();
	task::solve();

	return 0;
}
posted @ 2022-05-05 21:25  Administrator-09  阅读(1)  评论(0)  编辑  收藏  举报