题解 c

传送门

题解做法十分神仙

Sol 1:
\(F_i\) 为选了 \(i\) 个数的集合幂级数,其中 \([z^j](F_i)\) 为选了 \(i\) 个数异或和为 \(j\) 的方案数
再令 \(H\) 为所有数构成的集合幂级数,也即 \(\sum x^{a_i}\)
那么有转移

\[F_i=HF_{i-1}+(n-i+2)(i-1)F_{i-2} \]

是在钦定第 \(i\) 个数和之前某个数重复并减去
然后考虑答案是什么样子的

\[\begin{aligned} Ans&=F_m\\ \operatorname{FWT}(Ans)&=\operatorname{FWT}(F_m) \end{aligned}\]

这里令 \(F_m=\sum f_iH^i\)

\[\begin{aligned} \operatorname{FWT}(Ans)&=\operatorname{FWT}(\sum f_iH^i)\\ \operatorname{FWT}(Ans)&=\sum f_i\operatorname{FWT}^i(H)\\ \end{aligned}\]

那么现在只需要快速求出 \(F_m\) 的各次项系数就好了
按照定义求的话复杂度炸了
考虑上面那个递推式的转移路径
发现就是从 \([2, n]\) 中选若干个不相邻位置作为转移 2 的转移点,剩下的有转移 1
可以分治处理,合并两个区间需要分别维护这个区间的左/右端点选/不选时的多项式
求出来 \(F_m\) 之后还需要写一个多点求值
于是可以做到 \(O(n\log^2 n)\)

Sol 2:
选 m 个?可以背包嘛?
直接背包无法优化,考虑对 FWT 结果做背包
\(A_i\) 为第 \(i\) 个数构成的集合幂级数
一个合法方案将是 \(Ans=\prod\limits_{i=1}^m\operatorname{FWT}(A_i)\)
考虑 \([z^j]Ans\),它是从 \(n\) 个幂级数中选 \(m\) 个的所有方案的这一位的乘积之和
每个幂级数这一位都是 1 或 -1
那么设有 \(x\) 个 -1,有 \(y\) 个1,那么 \(x+y=n\)
然后根据 \(\sum\operatorname{FWT}(A_i)=\operatorname{FWT}(\sum A_i)\) 可以求出 \(-x+y\)
就可以解出 \(x, y\)
然后考虑上面提到的乘积之和
可以这样求: \([z^m](1-z)^x(1+z)^{n-x}\)
然后 \([z^k](1-z)^x=\dbinom{x}{k}(-1)^k\)
那么

\[\begin{aligned}[z^m](1-z)^x(1+z)^{n-x}&=\sum\limits_{i=0}^m\binom{x}{i}(-1)^i\binom{n-x}{m-i}\end{aligned} \]

然后后面这个式子大力展开之后可以 NTT
这样就可以做到 \(O(n\log n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 600010
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans;
}

int n, m, k;
int a[N];
const ll mod=998244353, inv2=(mod+1)>>1, rt=3, phi=mod-1;
inline void md(ll& a, ll b) {a+=b; a=a>=mod?a-mod:a;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	int now;
	ll f[2][130][128];
	void solve() {
		f[now][0][0]=1;
		for (int i=1; i<=n; ++i,now^=1) {
			memset(f[now^1], 0, sizeof(f[now^1]));
			for (int j=0; j<=m; ++j) {
				for (int s=0; s<k; ++s) if (f[now][j][s]) {
					md(f[now^1][j+1][s^a[i]], f[now][j][s]);
					md(f[now^1][j][s], f[now][j][s]);
				}
			}
		}
		for (int i=0; i<k; ++i) printf("%lld%c", f[now][m][i], " \n"[i==k-1]);
	}
}

namespace force2{
	int now;
	ll f[2][2050][2050];
	void solve() {
		f[now][0][0]=1;
		for (int i=1; i<=n; ++i,now^=1) {
			memset(f[now^1], 0, sizeof(f[now^1]));
			for (int j=0; j<=m; ++j) {
				for (int s=0; s<k; ++s) if (f[now][j][s]) {
					md(f[now^1][j+1][s^a[i]], f[now][j][s]);
					md(f[now^1][j][s], f[now][j][s]);
				}
			}
		}
		for (int i=0; i<k; ++i) printf("%lld%c", f[now][m][i], " \n"[i==k-1]);
	}
}

namespace task1{
	int ans[N];
	void solve() {
		for (int i=1; i<=n; ++i) ++ans[a[i]];
		for (int i=0; i<k; ++i) printf("%d%c", ans[i], " \n"[i==k-1]);
	}
}

namespace task2{
	ll ans[N], t[N];
	void fwt(ll* a, int len, int op) {
		for (int i=1; i<len; i<<=1) {
			for (int j=0,step=i<<1; j<len; j+=step) {
				for (int k=j; k<j+i; ++k) {
					ll t1=a[k], t2=a[k+i];
					a[k]=(t1+t2)%mod, a[k+i]=(t1-t2)%mod;
					if (op==-1) a[k]=a[k]*inv2%mod, a[k+i]=a[k+i]*inv2%mod;
				}
			}
		}
	}
	void solve() {
		for (int i=1; i<=n; ++i) ++ans[a[i]];
		// for (int i=0; i<k; ++i)
		// 	for (int j=0; j<k; ++j)
		// 		t[i^j]=(t[i^j]+ans[i]*ans[j])%mod;
		// for (int i=0; i<k; ++i) printf("%lld%c", (t[i]%mod+mod)%mod, " \n"[i==k-1]);

		fwt(ans, k, 1);
		for (int i=0; i<k; ++i) ans[i]=ans[i]*ans[i]%mod;
		fwt(ans, k, -1);
		ans[0]-=n;
		for (int i=0; i<k; ++i) printf("%lld%c", (ans[i]*inv2%mod+mod)%mod, " \n"[i==k-1]);
	}
}

namespace task{
	int rev[N], bln, bct;
	ll f[N], g[N], val[N], fac[N], inv[N];
	inline ll C(int n, int k) {return n<k?0ll:fac[n]*inv[n-k]%mod*inv[k]%mod;}
	void ntt(ll* a, int len, int op) {
		for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
		ll w, wn, t;
		for (int i=1; i<len; i<<=1) {
			wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
			for (int j=0,step=i<<1; j<len; j+=step) {
				w=1;
				for (int k=j; k<j+i; ++k,w=w*wn%mod) {
					t=a[k+i]*w%mod;
					a[k+i]=(a[k]-t)%mod;
					a[k]=(a[k]+t)%mod;
				}
			}
		}
		if (op==-1) {
			ll inv=qpow(len, mod-2);
			for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
		}
	}
	void fwt(ll* a, int len, int op) {
		for (int i=1; i<len; i<<=1) {
			for (int j=0,step=i<<1; j<len; j+=step) {
				for (int k=j; k<j+i; ++k) {
					ll t1=a[k], t2=a[k+i];
					a[k]=(t1+t2)%mod, a[k+i]=(t1-t2)%mod;
					if (op==-1) a[k]=a[k]*inv2%mod, a[k+i]=a[k+i]*inv2%mod;
				}
			}
		}
	}
	void solve() {
		int lim=max(max(n, m), k);
		fac[0]=fac[1]=1; inv[0]=inv[1]=1;
		for (int i=2; i<=lim; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=lim; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=lim; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		for (bln=1; bln<=2*n; bln<<=1,++bct);
		for (int i=0; i<bln; ++i) rev[i]=rev[i>>1]>>1|((i&1)<<(bct-1));
		for (int i=0; i<=n-m; ++i) f[i]=inv[i]*inv[n-m-i]%mod;
		for (int i=0; i<=m; ++i) g[i]=(i&1?-1:1)*inv[i]*inv[m-i]%mod;
		ntt(f, bln, 1); ntt(g, bln, 1);
		for (int i=0; i<bln; ++i) f[i]=f[i]*g[i]%mod;
		ntt(f, bln, -1);
		for (int i=0; i<=n; ++i) f[i]=f[i]*fac[i]%mod*fac[n-i]%mod;
		for (int i=1; i<=n; ++i) ++val[a[i]];
		fwt(val, k, 1);
		// cout<<"val: "; for (int i=0; i<k; ++i) cout<<val[i]<<' '; cout<<endl;
		for (int i=0; i<k; ++i) {
			int y=(n+val[i])>>1, x=n-y;
			// cout<<"i: "<<i<<' '<<x<<' '<<y<<endl;
			val[i]=f[x];
			// cout<<"f[x]: "<<f[x]<<endl;
			// val[i]=0;
			// // for (int j=0; j<=min(m, x); ++j) val[i]=(val[i]+C(x, j)*(j&1?-1:1)*C(n-x, m-j))%mod;
			// for (int j=0; j<=min(m, x); ++j) val[i]=(val[i]+qpow(-1, j)*inv[x-j]*inv[j]%mod*inv[n-m-(x-j)]%mod*inv[m-j])%mod;
			// val[i]=val[i]*fac[x]%mod*fac[n-x]%mod;
			// cout<<val[i]<<endl;
		}
		// cout<<"val: "; for (int i=0; i<k; ++i) cout<<val[i]<<' '; cout<<endl;
		fwt(val, k, -1);
		for (int i=0; i<k; ++i) printf("%lld%c", (val[i]%mod+mod)%mod, " \n"[i==k-1]);
	}
}

signed main()
{
	freopen("c.in", "r", stdin);
	freopen("c.out", "w", stdout);

	n=read(); m=read(); k=read();
	for (int i=1; i<=n; ++i) a[i]=read();
	// if (m==1) task1::solve();
	// else if (m==2) task2::solve();
	// else if (n<=128&&k<=128) force::solve();
	// else force2::solve();
	task::solve();

	return 0;
}
posted @ 2022-05-03 08:11  Administrator-09  阅读(3)  评论(0编辑  收藏  举报