题解 c
题解做法十分神仙
Sol 1:
设 \(F_i\) 为选了 \(i\) 个数的集合幂级数,其中 \([z^j](F_i)\) 为选了 \(i\) 个数异或和为 \(j\) 的方案数
再令 \(H\) 为所有数构成的集合幂级数,也即 \(\sum x^{a_i}\)
那么有转移
是在钦定第 \(i\) 个数和之前某个数重复并减去
然后考虑答案是什么样子的
这里令 \(F_m=\sum f_iH^i\)
那么现在只需要快速求出 \(F_m\) 的各次项系数就好了
按照定义求的话复杂度炸了
考虑上面那个递推式的转移路径
发现就是从 \([2, n]\) 中选若干个不相邻位置作为转移 2 的转移点,剩下的有转移 1
可以分治处理,合并两个区间需要分别维护这个区间的左/右端点选/不选时的多项式
求出来 \(F_m\) 之后还需要写一个多点求值
于是可以做到 \(O(n\log^2 n)\)
Sol 2:
选 m 个?可以背包嘛?
直接背包无法优化,考虑对 FWT 结果做背包
令 \(A_i\) 为第 \(i\) 个数构成的集合幂级数
一个合法方案将是 \(Ans=\prod\limits_{i=1}^m\operatorname{FWT}(A_i)\)
考虑 \([z^j]Ans\),它是从 \(n\) 个幂级数中选 \(m\) 个的所有方案的这一位的乘积之和
每个幂级数这一位都是 1 或 -1
那么设有 \(x\) 个 -1,有 \(y\) 个1,那么 \(x+y=n\)
然后根据 \(\sum\operatorname{FWT}(A_i)=\operatorname{FWT}(\sum A_i)\) 可以求出 \(-x+y\)
就可以解出 \(x, y\) 了
然后考虑上面提到的乘积之和
可以这样求: \([z^m](1-z)^x(1+z)^{n-x}\)
然后 \([z^k](1-z)^x=\dbinom{x}{k}(-1)^k\)
那么
然后后面这个式子大力展开之后可以 NTT
这样就可以做到 \(O(n\log n)\) 了
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 600010
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans;
}
int n, m, k;
int a[N];
const ll mod=998244353, inv2=(mod+1)>>1, rt=3, phi=mod-1;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
namespace force{
int now;
ll f[2][130][128];
void solve() {
f[now][0][0]=1;
for (int i=1; i<=n; ++i,now^=1) {
memset(f[now^1], 0, sizeof(f[now^1]));
for (int j=0; j<=m; ++j) {
for (int s=0; s<k; ++s) if (f[now][j][s]) {
md(f[now^1][j+1][s^a[i]], f[now][j][s]);
md(f[now^1][j][s], f[now][j][s]);
}
}
}
for (int i=0; i<k; ++i) printf("%lld%c", f[now][m][i], " \n"[i==k-1]);
}
}
namespace force2{
int now;
ll f[2][2050][2050];
void solve() {
f[now][0][0]=1;
for (int i=1; i<=n; ++i,now^=1) {
memset(f[now^1], 0, sizeof(f[now^1]));
for (int j=0; j<=m; ++j) {
for (int s=0; s<k; ++s) if (f[now][j][s]) {
md(f[now^1][j+1][s^a[i]], f[now][j][s]);
md(f[now^1][j][s], f[now][j][s]);
}
}
}
for (int i=0; i<k; ++i) printf("%lld%c", f[now][m][i], " \n"[i==k-1]);
}
}
namespace task1{
int ans[N];
void solve() {
for (int i=1; i<=n; ++i) ++ans[a[i]];
for (int i=0; i<k; ++i) printf("%d%c", ans[i], " \n"[i==k-1]);
}
}
namespace task2{
ll ans[N], t[N];
void fwt(ll* a, int len, int op) {
for (int i=1; i<len; i<<=1) {
for (int j=0,step=i<<1; j<len; j+=step) {
for (int k=j; k<j+i; ++k) {
ll t1=a[k], t2=a[k+i];
a[k]=(t1+t2)%mod, a[k+i]=(t1-t2)%mod;
if (op==-1) a[k]=a[k]*inv2%mod, a[k+i]=a[k+i]*inv2%mod;
}
}
}
}
void solve() {
for (int i=1; i<=n; ++i) ++ans[a[i]];
// for (int i=0; i<k; ++i)
// for (int j=0; j<k; ++j)
// t[i^j]=(t[i^j]+ans[i]*ans[j])%mod;
// for (int i=0; i<k; ++i) printf("%lld%c", (t[i]%mod+mod)%mod, " \n"[i==k-1]);
fwt(ans, k, 1);
for (int i=0; i<k; ++i) ans[i]=ans[i]*ans[i]%mod;
fwt(ans, k, -1);
ans[0]-=n;
for (int i=0; i<k; ++i) printf("%lld%c", (ans[i]*inv2%mod+mod)%mod, " \n"[i==k-1]);
}
}
namespace task{
int rev[N], bln, bct;
ll f[N], g[N], val[N], fac[N], inv[N];
inline ll C(int n, int k) {return n<k?0ll:fac[n]*inv[n-k]%mod*inv[k]%mod;}
void ntt(ll* a, int len, int op) {
for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
ll w, wn, t;
for (int i=1; i<len; i<<=1) {
wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
for (int j=0,step=i<<1; j<len; j+=step) {
w=1;
for (int k=j; k<j+i; ++k,w=w*wn%mod) {
t=a[k+i]*w%mod;
a[k+i]=(a[k]-t)%mod;
a[k]=(a[k]+t)%mod;
}
}
}
if (op==-1) {
ll inv=qpow(len, mod-2);
for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
}
}
void fwt(ll* a, int len, int op) {
for (int i=1; i<len; i<<=1) {
for (int j=0,step=i<<1; j<len; j+=step) {
for (int k=j; k<j+i; ++k) {
ll t1=a[k], t2=a[k+i];
a[k]=(t1+t2)%mod, a[k+i]=(t1-t2)%mod;
if (op==-1) a[k]=a[k]*inv2%mod, a[k+i]=a[k+i]*inv2%mod;
}
}
}
}
void solve() {
int lim=max(max(n, m), k);
fac[0]=fac[1]=1; inv[0]=inv[1]=1;
for (int i=2; i<=lim; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=lim; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<=lim; ++i) inv[i]=inv[i-1]*inv[i]%mod;
for (bln=1; bln<=2*n; bln<<=1,++bct);
for (int i=0; i<bln; ++i) rev[i]=rev[i>>1]>>1|((i&1)<<(bct-1));
for (int i=0; i<=n-m; ++i) f[i]=inv[i]*inv[n-m-i]%mod;
for (int i=0; i<=m; ++i) g[i]=(i&1?-1:1)*inv[i]*inv[m-i]%mod;
ntt(f, bln, 1); ntt(g, bln, 1);
for (int i=0; i<bln; ++i) f[i]=f[i]*g[i]%mod;
ntt(f, bln, -1);
for (int i=0; i<=n; ++i) f[i]=f[i]*fac[i]%mod*fac[n-i]%mod;
for (int i=1; i<=n; ++i) ++val[a[i]];
fwt(val, k, 1);
// cout<<"val: "; for (int i=0; i<k; ++i) cout<<val[i]<<' '; cout<<endl;
for (int i=0; i<k; ++i) {
int y=(n+val[i])>>1, x=n-y;
// cout<<"i: "<<i<<' '<<x<<' '<<y<<endl;
val[i]=f[x];
// cout<<"f[x]: "<<f[x]<<endl;
// val[i]=0;
// // for (int j=0; j<=min(m, x); ++j) val[i]=(val[i]+C(x, j)*(j&1?-1:1)*C(n-x, m-j))%mod;
// for (int j=0; j<=min(m, x); ++j) val[i]=(val[i]+qpow(-1, j)*inv[x-j]*inv[j]%mod*inv[n-m-(x-j)]%mod*inv[m-j])%mod;
// val[i]=val[i]*fac[x]%mod*fac[n-x]%mod;
// cout<<val[i]<<endl;
}
// cout<<"val: "; for (int i=0; i<k; ++i) cout<<val[i]<<' '; cout<<endl;
fwt(val, k, -1);
for (int i=0; i<k; ++i) printf("%lld%c", (val[i]%mod+mod)%mod, " \n"[i==k-1]);
}
}
signed main()
{
freopen("c.in", "r", stdin);
freopen("c.out", "w", stdout);
n=read(); m=read(); k=read();
for (int i=1; i<=n; ++i) a[i]=read();
// if (m==1) task1::solve();
// else if (m==2) task2::solve();
// else if (n<=128&&k<=128) force::solve();
// else force2::solve();
task::solve();
return 0;
}