题解 P6824 「EZEC-4」可乐
题意简述
给定一个长度为 \(n\) 的序列 \(a\) 和一个整数 \(k\) ,你需要构造一个数字 \(x\) ,使得 \(\sum_\limits{i=1}^n [a_i\oplus x\leq k]\) 最大,你需要求出这个最大值。
\(1 \leq n,k,a_i \leq 10^6\)。
Solution
这里给出一个递归的写法。
建出包含序列中所有数字的 Trie 树,然后考虑一个函数 \(solve(x,i)\) 表示当前在 Trie 的 \(x\) 节点上,并且当前节点的子树中的所所有数字与之前构造的 \(x\) 异或的最高的 \(i-1\) 位都和 \(k\) 的最高的 \(i-1\) 位相同的情况下能喝到的可乐的最大数量。最终答案为 \(solve(rt,\lfloor\log_2 k\rfloor)\) (\(rt\) 表示 Trie 的根节点)。
对于求解 \(solve(x,i)\) 的话可以分类讨论(下面用 \(ls(x)\) 表示 \(x\) 在 Trie 上的左儿子,\(rs(x)\) 表示 \(x\) 在 Trie 上的右儿子 ,\(s(x)\) 表示以 \(x\) 为根的子树内的数字个数):
- 若 \(k\) 的第 \(i\) 位为 \(1\) ,那么如果构造的 \(x\) 的第 \(i\) 位为 \(1\) ,当前子树中的所有第 \(i\) 位为 \(1\) 的数异或后都一定 \(< k\) ,第 \(i\) 位为 \(0\) 的需要递归处理,个数为 \(solve(ls(x),i-1)+s(rs(x))\),构造的 \(x\) 的第 \(i\) 位为 \(0\) 同理,个数为 \(solve(rs(x),i-1)+s(ls(x))\)。最后答案就是这两种情况取最大值。
- 若 \(k\) 的第 \(i\) 位为 \(0\),那么如果构造的 \(x\) 的第 \(i\) 位为 \(1\) ,那么只有第 \(i\) 位同为 \(1\) 的数才可能 \(\leq k\)(如果第 \(i\) 位为 \(0\) 的话异或之后就是 \(1\) 了),答案为 \(solve(rs(x),i-1)\) ,构造的 \(x\) 的第 \(i\) 位为 \(0\) 同理,个数为 \(solve(ls(x),i-1)\),最后答案就是这两种情况取最大值。
时间复杂度为 \(\mathcal{O}(k+n\log_2 n)\) (递归的复杂度是根据每个节点都会访问一次且仅访问一次确定的)。
这样写空间很大(我过不去),考虑有没有什么优化。
注意到建出 Trie 完全是为了方便找数,而不是利用什么树形性质。我们可以在递归的时候用一个 vector
来代替。也就是说,设 \(solve(v,i)\) 表示当前 \(v\) 中的所有数字与之前构造的 \(x\) 异或的最高的 \(i-1\) 位都和 \(k\) 的最高的 \(i-1\) 位相同,然后每次就把 \(v\) 分裂成两部分(按照第 \(i\) 位是 \(0\) 还是 \(1\)),递归方式和上面一样,只不过递归子节点变成了传递一个 vector
。
时间复杂度和上面 Trie 相同,空间复杂度的话由于递归之后会释放空间,所以同时空间最多是 \(\mathcal{O}(n \log_2 n)\),复杂度正确,而且跑得还挺快。
代码如下:
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <iostream>
#include <queue>
using namespace std;
inline int read() {
int num = 0 ,f = 1; char c = getchar();
while (!isdigit(c)) f = c == '-' ? -1 : f ,c = getchar();
while (isdigit(c)) num = (num << 1) + (num << 3) + (c ^ 48) ,c = getchar();
return num * f;
}
const int N = 1e6 + 5 ,LG = 19;
int n ,k;
vector <int> v;
inline int solve(vector <int> v ,int i) {
// if (v.size() == 0) return 0;
if (i == 0) return v.size();
vector <int> v0 ,v1;
for (auto x : v) {
if (x >> i & 1) v1.push_back(x);
else v0.push_back(x);
}
if (k >> i & 1)
return max(v0.size() + solve(v1 ,i - 1) ,v1.size() + solve(v0 ,i - 1));
else
return max(solve(v0 ,i - 1) ,solve(v1 ,i - 1));
}
signed main() {
n = read() ,k = read();
for (int i = 1; i <= n; i++) {
int x = read();
v.push_back(x);
}
printf("%d\n" ,solve(v ,LG));
return 0;
}