传送门
原题
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" value (0x3f3f3f3f); not referenced by the code shown in this file.
#define INF 0x3f3f3f3f
// Capacity of every fixed-size array in this file. It must exceed n and also the
// transform length 1<<m chosen in main. NOTE(review): confirm against the problem limits.
#define N 3000010
#define ll long long
// Disabled switch that would widen every `int` to long long (main is declared
// `signed main()` so it would still compile with this enabled).
//#define int long long
// 2 MiB input buffer and its read cursor (p1) / end pointer (p2) for the getchar() macro below.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): when the buffer is exhausted (p1==p2) it refills
// it with a single fread of up to 1<<21 bytes from stdin; it yields EOF if nothing was
// read, otherwise the next buffered char (which may be negative for bytes >= 0x80
// where char is signed).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads the next (optionally negative) decimal integer from the input stream.
// Any non-digit characters before the number are skipped; every '-' among them
// flips the sign. Returns 0 if the input ends before a digit is found (the
// previous version looped forever on EOF).
inline int read() {
    int ans = 0, f = 1;
    // Held in an int so an EOF result is compared without narrowing through char.
    int c = getchar();
    // Explicit range checks instead of isdigit(): a negative char passed to
    // isdigit() is undefined behaviour.
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}
// n: number of input values; m: number of value bits (main picks the smallest
// m >= 1 with (1<<m) > maxn); k: number of choices per item (indices 0..k-1 of a[i]).
int n, m, k=2;
// a[i][0] is the i-th input value and a[i][1] is 0 (both filled in main);
// w[j] is the weight of choosing a[i][j] (2 for the value, 1 for 0);
// maxn is the largest input value.
int a[N][2], w[]={2, 1}, maxn;
// Modulus, and the inverse of 2 modulo it: 2*((mod+1)/2) = mod+1 == 1 (mod is odd).
const ll mod=998244353, inv2=(mod+1)>>1;
// Binary exponentiation: returns a^b reduced modulo `mod` (the result may be a
// negative representative when `a` is negative). Expects b >= 0; a negative
// exponent never reaches 0 under right shifts.
inline ll qpow(ll a, ll b) {
    ll result = 1;
    while (b != 0) {
        if (b & 1) result = result * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}
// In-place XOR (Walsh-Hadamard) transform of a[0..len-1] modulo `mod`; len is
// expected to be a power of two. op == 1 is the forward transform; op == -1 is the
// inverse, which halves both outputs of every butterfly (multiplying by inv2).
// Results are left as representatives in (-mod, mod) and may be negative.
void fwt(ll* a, int len, int op) {
    const bool inverse = (op == -1);
    for (int half = 1; half < len; half <<= 1) {
        const int block = half << 1;
        for (int base = 0; base < len; base += block) {
            for (int pos = base; pos < base + half; ++pos) {
                const ll x = a[pos];
                const ll y = a[pos + half];
                ll sum = (x + y) % mod;
                ll diff = (x - y) % mod;
                if (inverse) {
                    sum = sum * inv2 % mod;
                    diff = diff * inv2 % mod;
                }
                a[pos] = sum;
                a[pos + half] = diff;
            }
        }
    }
}
namespace force{
// Brute-force reference solution. For every item i it builds the polynomial
// f_i = sum_j w[j] * x^{a[i][j]} (in the XOR-convolution sense), transforms it,
// multiplies it pointwise into `ans`, and inverts the accumulated product once
// at the end. Costs O(2^m * m) per item; intended for cross-checking task::solve.
ll ans[N], f[N];
void solve() {
    const int len = 1 << m;
    for (int i = 0; i < len; ++i) ans[i] = 1;
    for (int i = 1; i <= n; ++i) {
        // fwt only reads/writes the first len entries, so clearing those is enough
        // (clearing all N entries for every item made this O(n * N)).
        std::fill(f, f + len, 0LL);
        // Accumulate rather than assign: when a[i][0] == a[i][1] (an input value of 0)
        // both weights land on the same index and must add up (2 + 1 = 3).
        for (int j = 0; j < k; ++j) f[a[i][j]] += w[j];
        fwt(f, len, 1);
        for (int j = 0; j < len; ++j) ans[j] = ans[j] * f[j] % mod;
    }
    fwt(ans, len, -1);
    // Coefficient of x^0 minus 1, normalised into [0, mod).
    // Presumably the -1 removes the empty selection — confirm against the problem.
    cout << ((ans[0] - 1) % mod + mod) % mod << endl;
}
}
namespace task{
int b[N];
ll f[N], cnt[N][4], ans[N];
void solve() {
for (int i=0; i<(1<<m); ++i) ans[i]=1;
for (int t=0; t<(1<<k); ++t) {
memset(f, 0, sizeof(f));
memset(b, 0, sizeof(b));
for (int i=1; i<=n; ++i) for (int j=0; j<k; ++j) if (t&(1<<j)) b[i]^=a[i][j];
for (int i=1; i<=n; ++i) ++f[b[i]];
fwt(f, 1<<m, 1);
for (int q=0; q<(1<<m); ++q) cnt[q][t]=f[q];
}
for (int i=0; i<(1<<m); ++i) {
// cout<<"i: "<<i<<endl;
// cout<<"cnt: "; for (int j=0; j<(1<<k); ++j) cout<<cnt[i][j]<<' '; cout<<endl;
fwt(cnt[i], 1<<k, -1);
// cout<<"cnt: "; for (int j=0; j<(1<<k); ++j) cout<<cnt[i][j]<<' '; cout<<endl;
for (int s=0; s<(1<<k); ++s) {
ll sum=0;
for (int j=0; j<k; ++j)
if (s&(1<<j)) sum=(sum-w[j])%mod;
else sum=(sum+w[j])%mod;
ans[i]=ans[i]*qpow(sum, cnt[i][s]);
}
}
fwt(ans, 1<<m, -1);
// cout<<"ans: "; for (int i=0; i<(1<<m); ++i) cout<<ans[i]<<' '; cout<<endl;
cout<<((ans[0]-1)%mod+mod)%mod<<endl;
}
}
// Entry point. Reads n followed by n values into a[1..n][0] (with a[i][1] = 0),
// chooses the bit width m, and runs the fast solver.
signed main()
{
    n = read();
    for (int i = 1; i <= n; ++i) {
        a[i][0] = read();
        a[i][1] = 0;               // the second choice of every item is the value 0
        maxn = max(maxn, a[i][0]);
    }
    // Smallest m >= 1 such that every input value is below 1 << m.
    // NOTE(review): 1 << m must not exceed N, or the transforms index past the arrays.
    m = 1;
    while ((1 << m) <= maxn) ++m;
    // force::solve();   // brute-force reference, kept for cross-checking
    task::solve();
    return 0;
}