Atcoder Regular Contest 139 F - Many Xor Optimization Problems

提供一种不用推式子且比较好想的分治 NTT 做法,虽然复杂度要劣一点。

首先,题目要我们求子集异或和的最大值的和,自然需要建出这 \(n\) 个数的线性基。对于一组固定的 \(a_1\sim a_n\),假如我们已经求出了其极大线性无关的子集 \(a_{i_1},a_{i_2},\cdots,a_{i_k}\),并使用高斯消元将其削成了简化阶梯型的形式,那么答案自然就是线性基中这 \(k\) 个数的异或和,换句话说,答案第 \(d\) 位为 \(1\) 当且仅当:

  • \(d\) 是这组线性基的主元列
  • \(d\) 不是这组线性基的主元列,且将这组线性基削成了简化阶梯型的形式之后第 \(d\) 位上恰有奇数个 \(1\)

我们考虑枚举线性基大小 \(k\),那么发现求解线性基大小为 \(k\) 的所有矩阵的答案之和可以分为两个完全独立的部分:

  • 求解所有大小为 \(k\times m\) 的行线性无关的矩阵答案之和。
  • 求解以这 \(k\) 个向量为基底的 \(n\times m\) 的矩阵个数。

求出两部分的答案把它们乘起来再求和即可。

后面一部分的求解过程稍微容易点。我们假设字典序最小的一组基底的下标集合为 \(i_1,i_2,\cdots,i_k(i_1<i_2<\cdots<i_k)\),那么对于一组不在这个基底里的向量 \(a_j\),假设其位于 \(a_p\)\(a_{p+1}\) 之间,那么它必然可以表示为 \(a_{i_1},a_{i_2},\cdots,a_{i_p}\) 的线性组合,每个向量有 \(0/1\) 两种系数,所以 \(a_j\) 的取值方案数就是 \(2^p\)。换句话说,满足字典序最小的一组基底的下标集合为 \(i_1,i_2,\cdots,i_k(i_1<i_2<\cdots<i_k)\) 的矩阵个数为

\[\prod 2^{n-i_j-(k-j)} \]

这个问题可以进一步被抽象为,在 \(2^0,2^1,2^2,\cdots,2^{n-1}\)\(n\) 个幂中选 \(k\) 个乘起来再求和,可以使用 q-binomial 求解。

考虑怎么求前一部分的答案,先考虑暴力怎么做。我们枚举一个二进制位 \(d\),计算有多少个 \(k\times m\) 的行线性无关的矩阵答案的第 \(d\) 位为 \(1\),乘以 \(2^d\) 的系数以后求和。因为对于任一行线性无关的矩阵,其简化阶梯型是唯一的,而对于一个简化阶梯型,进行高斯消元后能得到其的矩阵个数就是 \(k\times k\) 线性无关矩阵的个数——根据经典结论,其等于 \(\prod\limits_{i=0}^{k-1}(2^k-2^i)\)。因此我们可以转而对简化阶梯型计数,这样只用再乘一个 \(\prod\limits_{i=0}^{k-1}(2^k-2^i)\) 即可。我们假设简化阶梯型的主元列为 \(a_1,a_2,\cdots,a_k(a_1<a_2<\cdots<a_k)\),根据前面的讨论可以想到分两类讨论:

  • \(d\in\{a_1,a_2,\cdots,a_k\}\):这种情况答案第 \(d\) 位必然为 \(1\),而对应这个简化阶梯型的矩阵个数为 \(\prod\limits_{i=1}^{k}2^{a_i-(i-1)}\),所以稍微抽象一下问题可以变为,从 \(2^0,2^1,2^2,\cdots,2^{n-1}\)\(n\) 个幂中选 \(k\) 个幂乘起来乘起来求和,其中有一个幂要乘两遍。分治 NTT 过程中维护 \(dp_{0/1,i}\) 表示当前区间里选了 \(i\) 个幂,选/没选那个乘两遍的幂即可做到 \(O(n\log^2n)\)
  • \(d\notin\{a_1,a_2,\cdots,a_k\}\):那么如果 \(d>a_k\)\(d\) 位必然不可能是 \(1\),否则恰有一半的概率简化阶梯型第 \(d\) 位上 \(1\) 的个数是奇数,因此求出以后 \(\prod\limits_{i=1}^{k}2^{a_i-(i-1)}\) 再成个 \(\dfrac{1}{2}\) 就是方案数。这个问题可以被抽象为从 \(2^0,2^1,2^2,\cdots,2^{n-1}\)\(n\) 个幂中选 \(k\) 个幂乘起来,再选一个没被选择的幂 \(2^i\) 作为关键点,并且要求存在一个被选择的幂 \(2^j\) 满足 \(j>i\)。分治 NTT 过程中维护 \(dp_{0/1/2,i}\) 表示当前区间里选了 \(i\) 个幂,没选关键点/选了关键点且后面没有被选的幂/选了关键点且后面有被选的幂的方案数可以做到 \(O(n\log^2n)\)
const int MAXN=2.5e5;
const int MOD=998244353;
const int MAXP=1<<19;
const int INV2=MOD+1>>1;
const int pr=3;
const int ipr=332748118;
int n,m,res,f[MAXN+5],facq[MAXN+5],ifacq[MAXN+5];
int qpow(int x,int e){int ret=1;for(;e;e>>=1,x=1ll*x*x%MOD)if(e&1)ret=1ll*ret*x%MOD;return ret;}
void init_fac(int n){
	facq[0]=1;
	for(int i=1;i<=n;i++)facq[i]=1ll*facq[i-1]*(qpow(2,i)-1)%MOD;
	for(int i=0;i<=n;i++)ifacq[i]=qpow(facq[i],MOD-2);
	for(int i=1;i<=n;i++)f[i]=1ll*facq[i]*qpow(2,1ll*i*(i-1)/2%(MOD-1))%MOD;
}
int calc(int n,int k){return 1ll*facq[n]*ifacq[k]%MOD*ifacq[n-k]%MOD;}
int rev[MAXP+5];
void NTT(vector<int>&a,int len,int type){
	int lg=31-__builtin_clz(len);
	for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
	for(int i=0;i<len;i++)if(rev[i]<i)swap(a[i],a[rev[i]]);
	for(int i=2;i<=len;i<<=1){
		int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
		for(int j=0;j<len;j+=i){
			for(int k=0,w=1;k<(i>>1);k++,w=1ll*w*W%MOD){
				int X=a[j+k],Y=1ll*w*a[(i>>1)+j+k]%MOD;
				a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
			}
		}
	}
	if(type<0){
		int iv=qpow(len,MOD-2);
		for(int i=0;i<len;i++)a[i]=1ll*a[i]*iv%MOD;
	}
}
vector<int>conv(vector<int>a,vector<int>b){
	int LEN=1,lim=a.size()+b.size()-1;while(LEN<a.size()+b.size())LEN<<=1;
	a.resize(LEN,0);b.resize(LEN,0);NTT(a,LEN,1);NTT(b,LEN,1);
	for(int i=0;i<LEN;i++)a[i]=1ll*a[i]*b[i]%MOD;
	NTT(a,LEN,-1);while(a.size()>lim)a.ppb();return a;
}
struct dat1{vector<int>dp[2];};
dat1 solve1(int l,int r){
	if(l==r){
		dat1 x;x.dp[0].resize(2);x.dp[1].resize(2);
		x.dp[0][0]=1;x.dp[0][1]=qpow(2,l);x.dp[1][1]=qpow(4,l);
		return x;
	}int mid=l+r>>1;dat1 L=solve1(l,mid),R=solve1(mid+1,r),res;
	res.dp[0]=conv(L.dp[0],R.dp[0]);res.dp[1].resize(r-l+2);
	vector<int>tmp=conv(L.dp[1],R.dp[0]);for(int i=0;i<=r-l+1;i++)res.dp[1][i]=(res.dp[1][i]+tmp[i])%MOD;
	           tmp=conv(L.dp[0],R.dp[1]);for(int i=0;i<=r-l+1;i++)res.dp[1][i]=(res.dp[1][i]+tmp[i])%MOD;
	return res;
}
struct dat2{vector<int>dp[3];};
dat2 solve2(int l,int r){
	if(l==r){
		dat2 x;x.dp[0].resize(2);x.dp[1].resize(2);x.dp[2].resize(2);
		x.dp[0][0]=1;x.dp[0][1]=qpow(2,l);x.dp[1][0]=qpow(2,l);
		return x;
	}int mid=l+r>>1;dat2 L=solve2(l,mid),R=solve2(mid+1,r),res;
	res.dp[0]=conv(L.dp[0],R.dp[0]);res.dp[1].resize(r-l+2);res.dp[2].resize(r-l+2);
	vector<int>tmp=conv(L.dp[0],R.dp[1]);for(int i=0;i<=r-l+1;i++)res.dp[1][i]=(res.dp[1][i]+tmp[i])%MOD;
	           tmp=conv(L.dp[0],R.dp[2]);for(int i=0;i<=r-l+1;i++)res.dp[2][i]=(res.dp[2][i]+tmp[i])%MOD;
	           tmp=conv(L.dp[2],R.dp[0]);for(int i=0;i<=r-l+1;i++)res.dp[2][i]=(res.dp[2][i]+tmp[i])%MOD;
	R.dp[0][0]=0;
	           tmp=conv(L.dp[1],R.dp[0]);for(int i=0;i<=r-l+1;i++)res.dp[2][i]=(res.dp[2][i]+tmp[i])%MOD;
	for(int i=0;i<L.dp[1].size();i++)res.dp[1][i]=(res.dp[1][i]+L.dp[1][i])%MOD;
	return res;
}
int main(){
	scanf("%d%d",&n,&m);init_fac(MAXN);dat1 dp1=solve1(0,m-1);dat2 dp2=solve2(0,m-1);
	for(int i=1;i<=min(n,m);i++)res=(res+1ll*(dp1.dp[1][i]+1ll*INV2*dp2.dp[2][i])%MOD*qpow(INV2,1ll*i*(i-1)/2%(MOD-1))%MOD*f[i]%MOD*calc(n,n-i))%MOD;
	printf("%d\n",res);
	return 0;
}
posted @ 2024-01-20 16:45  tzc_wk  阅读(33)  评论(0编辑  收藏  举报