[AtCoder Regular Contest 156][D. Xor Sum 5]
题目链接:D - Xor Sum 5
题目大意:给定一个长度为 \(N(1\le N\le 1000)\) 的数组 \(A(0\le A_i\le 1000)\),以及一个正整数 \(K(1\le K\le 10^{12})\)。现在要在这 \(N\) 个数字里挑一个 \(A_i\) 出来,并重复 \(K\) 次这样的操作,对应的得分就是挑出来的 \(A_i\) 之和。一共有 \(N^K\) 种挑选方案,求所有挑选方案对应得分的异或和。
简化问题
将问题转化成生成函数来表示,那就是令 \(P(x)=(x^{A_1}+x^{A_2}+...+x^{A_N})^K\),并求所有系数为奇数对应项的指数之和。那么考虑模 \(2\) 意义下的多项式运算,就是求满足 \([x^S](x^{A_1}+x^{A_2}+...+x^{A_N})^K\) 的值为 \(1\) 对应的所有 \(S\) 的异或和。
接着考虑模 \(2\) 意义下的一个经典式子:
反复套用对应式子,我们得到:
其实对于这个式子,我们也可以用 \(\texttt{Lucas}\) 定理的推论来证明:
- 经典推论:\((x+y)^P\equiv x^P+y^P\pmod P\),其中 \(P\) 为质数(考虑 \(\binom{P}{i}\) 不为 \(P\) 的倍数当且仅当 \(i=0\) 或 \(P\))
- 对 \(P=2\),得出 \((x+y)^2\equiv x^2+y^2 \pmod 2\)
- 反复套用该式,得出 \((\sum_{i=1}^n x_i)^2\equiv (\sum_{i=1}^{n-1} x_i)^2+x_n^2\equiv ...\equiv \sum_{i=1}^n x_i^2\)
回到原问题,考虑 \(K\) 的二进制表达 \(K=\sum_{i=1}^{M}2^{k_i}\),那么就有:
记第 \(i\) 次选择的数字下标为 \(X_i\),于是问题就由求 \(\sum_{i=1}^K A_{X_i}\) 的异或和变成了求 \(\sum_{m=1}^M A_{X_m}\times 2^{k_m}=S\) 的异或和。
组合计数
我们可以将这一问题转化为一道组合计数问题:对每个结果 \(S\),确定 \(\sum_{m=1}^M A_{X_m}\times 2^{k_m}=S\) 的方案数。由于我们最后求的是异或和,那么我们只需要统计方案数为奇数的答案即可,进一步地,我们考虑求和结果在二进制的某一位上为 \(1\) 的方案数。
考虑动态维护所有满足方案数为奇数的 \(S\) 组成的集合。设当前考虑的是二进制第 \(i\) 位上的答案,显然更低位的信息我们是不需要的,所以我们动态维护所有的 \(\lfloor\frac{S}{2^i}\rfloor\)。又因为当 \(k_m>i\) 时,\(A_{X_m}\times 2^{k_m}\) 对答案的第 \(i\) 位不产生贡献,所以我们可以等到 \(i=k_m\) 时再考虑对应 \(A_{X_m}\times 2^{k_m}\) 的选择。
于是在过程中,当我们考虑到第 \(i\) 位时,集合中的元素个数。可以估算得出,集合中元素的最大值对应上界为:
所以在这个过程中,集合的大小始终是不超过 \(2\max (A)\) 的。于是我们直接暴力维护集合即可在 \(O(N^2\log K)\) 的复杂度内解决问题,每次统计答案时只需要统计集合内有多少个奇数即可。使用 \(\texttt{bitset}\) 可以进行进一步的常数优化。
要注意的是,由于在考虑前面几位的答案时,对于 \(k_m>i\) 对应的 \(A_{X_m}\times 2^{k_m}\) 选择方案数还没有被统计进来,所以实际上这时对应的方案数还要乘上一个 \(N\) 的若干次方 \(N^{阿巴阿巴}\)。因此当还存在 \(k_m>i\) 时,我们还需要根据 \(N\) 的奇偶性来判断当前位置是否可能有值。此外当 \(i=k_M\) 时我们也可以直接计算最终答案,具体实现见代码。
#include<bits/stdc++.h>
using namespace std;
#define N 2048
int n;
bitset<N>a,s,t;
long long k,ans;
int main()
{
scanf("%d%lld",&n,&k);
for(int i=1;i<=n;i++){
int x;
scanf("%d",&x);
a.flip(x);
}
s[0]=1;
for(int i=0;i<=40;i++){
for(int j=0;j<N;j++)if(s[j])
s.flip(j/2),s.flip(j);
if(k>>i&1){
t.reset();
k^=(1ll<<i);
for(int x=0;x<N/2;x++)if(a[x])
t=t^(s<<x);
s=t;
}
if(k==0){
long long res=0;
for(int j=0;j<N;j++)if(s[j])res^=j;
ans+=res<<i;
return printf("%lld\n",ans),0;
}
if(n&1){
long long o=0;
for(int j=1;j<N;j+=2)if(s[j])o^=1;
ans+=o<<i;
}
}
}