P10218 [省选联考 2024] 魔法手杖 题解
Description
给定 \(a_1,a_2,\dots,a_n\) 以及 \(b_1,b_2,\dots,b_n\),满足 \(a_i \in [0,2^k-1]\) 以及 \(b_i\geq 0\),你需要给出 \(S \subseteq \{1,2,\dots,n\}\) 以及 \(x \in [0,2^k-1]\) 满足以下条件:
- \(\sum \limits_{i\in S} b_i\leq m\);
- 满足以上条件的前提下,最大化 \(val(S,x)=\min(\min \limits_{i \in S}(a_i+x),\min \limits_{i \in U \backslash S}(a_i \oplus x))\) 的值。
你只需要给出最大的 \(val(S,x)\) 的值即可。
设 \(\sum n\) 表示单组测试点各组数据 \(n\) 的和。对于所有测试数据,
- \(T \geq 1\);
- \(1 \leq n \leq 10^5\),\(1 \leq \sum n \leq 5\times 10^5\);
- \(0 \leq m \leq 10^9\);
- \(0 \leq k \leq 120\);
- \(\forall 1 \leq i \leq n, 0 \leq a_i<2^k\);
- \(\forall 1 \leq i \leq n, 0 \leq b_i \leq 10^9\)。
Solution
考虑建出 \(a\) 数组的 trie 树然后按照从高位到低位确定答案。
设当前走到第 \(k\) 位,trie 树节点权值为 \(val\),目前确定了 \(x\) 和 \(ans\) 的这些位,\(nowm\) 表示剩下的权值,\(minv\) 表示目前已经确定放到 \(S\) 中的数 \(a\) 的最小值。容易发现 \(x\oplus val=ans\)。
首先需要判断 \(ans\) 的第 \(k\) 位能否为 \(1\)。
不妨设 \(x\) 这位为 \(0\),对于 \(1\) 的情况同理。那么要让 \(ans\) 这位为 \(1\),就必须要把 \(val\) 子树内这位为 \(0\) 的数加到 \(S\) 中去,因为单靠异或是不能满足要求的。
那么如果 \(\min\{minv,minb_{ch_{val,0}}\}+2^k-1\geq ans+2^k\),就能满足 \(+\) 的部分大于等于 \(ans+2^k\),因为这里让后面 \(x\) 的位尽可能大一定是最优的。于是往这个儿子递归求答案即可。
然后是对于 \(ans\) 的第 \(k\) 位不能为 \(1\) 的情况。
还是设 \(x\) 这位为 \(0\)。那么 \(val\) 子树内为 \(1\) 的显然异或出来 \(\geq ans+2^k\),就不用管了,直接走到 \(ch_{val,0}\) 递归求解即可。
如果当前的 \(val\) 只有一个儿子需要特判。
容易发现对于 trie 树上的每个点只会遍历一次,所以时间复杂度是 \(O(nk)\)。
Code
#include <bits/stdc++.h>
// #define int int64_t
namespace FASTIO {
char ibuf[1 << 21], *p1 = ibuf, *p2 = ibuf;
char getc() {
return p1 == p2 && (p2 = (p1 = ibuf) + fread(ibuf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++;
}
template<class T> bool read(T &x) {
x = 0; int f = 0; char ch = getc();
while (ch < '0' || ch > '9') f |= ch == '-', ch = getc();
while (ch >= '0' && ch <= '9') x = (x * 10) + (ch ^ 48), ch = getc();
x = (f ? -x : x); return 1;
}
template<typename A, typename ...B> bool read(A &x, B &...y) { return read(x) && read(y...); }
char obuf[1 << 21], *o1 = obuf, *o2 = obuf + (1 << 21) - 1;
void flush() { fwrite(obuf, 1, o1 - obuf, stdout), o1 = obuf; }
void putc(char x) { *o1++ = x; if (o1 == o2) flush(); }
template<class T> void write(T x) {
if (!x) putc('0');
if (x < 0) x = -x, putc('-');
char c[40]; int tot = 0;
while (x) c[++tot] = x % 10, x /= 10;
for (int i = tot; i; --i) putc(c[i] + '0');
}
void write(char x) { putc(x); }
template<typename A, typename ...B> void write(A x, B ...y) { write(x), write(y...); }
struct Flusher {
~Flusher() { flush(); }
} flusher;
} // namespace FASTIO
using FASTIO::read; using FASTIO::putc; using FASTIO::write;
using i128 = __int128_t;
const int kMaxN = 1e5 + 5, kMaxK = 124, kMaxT = kMaxN * kMaxK;
int sid, n, m, k, tot;
int b[kMaxN], trie[kMaxT][2];
int64_t sumb[kMaxT];
i128 ans, a[kMaxN], mina[kMaxT];
void ins(i128 a, int b) {
int cur = 0;
for (int i = k - 1; ~i; --i) {
int c = (a >> i & 1);
if (!trie[cur][c]) trie[cur][c] = ++tot, mina[tot] = a;
cur = trie[cur][c];
mina[cur] = std::min(mina[cur], a), sumb[cur] += b;
}
}
void solve(int cur, i128 nowval, i128 x, i128 nowans, i128 minv, int m, int k) {
assert((x ^ nowans) == nowval);
assert(m >= 0);
if (!~k) {
ans = std::max(ans, nowans);
return;
}
assert(trie[cur][0] || trie[cur][1]);
i128 pw = (i128)1 << k;
if (!trie[cur][0]) {
// x 这位为 0
bool fl = 0;
if (minv + x + pw - 1 >= nowans + pw) {
fl = 1;
solve(trie[cur][1], nowval + pw, x, nowans + pw, minv, m, k - 1);
}
// x 这位为 1
if (std::min(minv, mina[trie[cur][1]]) + x + pw + pw - 1 >= nowans + pw && m >= sumb[trie[cur][1]]) {
fl = 1;
ans = std::max(ans, std::min(std::min(minv, mina[trie[cur][1]]) + x + pw + pw - 1, nowans + pw + pw - 1));
}
if (!fl) {
ans = std::max(ans, std::min(nowans + pw - 1, minv + x + pw - 1));
solve(trie[cur][1], nowval + pw, x + pw, nowans, minv, m, k - 1);
}
} else if (!trie[cur][1]) {
bool fl = 0;
// x 这位为 1
if (minv + x + pw + pw - 1 >= nowans + pw) {
fl = 1;
solve(trie[cur][0], nowval, x + pw, nowans + pw, minv, m, k - 1);
}
// x 这位为 0
if (std::min(minv, mina[trie[cur][0]]) + x + pw - 1 >= nowans + pw && m >= sumb[trie[cur][0]]) {
fl = 1;
ans = std::max(ans, std::min(std::min(minv, mina[trie[cur][1]]) + x + pw - 1, nowans + pw + pw - 1));
}
if (!fl) {
ans = std::max(ans, std::min(nowans + pw - 1, minv + x + pw + pw - 1));
solve(trie[cur][0], nowval, x, nowans, minv, m, k - 1);
}
} else {
// ans 这位为 1
assert(trie[cur][0] && trie[cur][1]);
bool fl = 0;
if (std::min(minv, mina[trie[cur][0]]) + x + pw - 1 >= nowans + pw && m >= sumb[trie[cur][0]]) {
fl = 1;
solve(trie[cur][1], nowval + pw, x, nowans + pw, std::min(minv, mina[trie[cur][0]]), m - sumb[trie[cur][0]], k - 1);
}
if (std::min(minv, mina[trie[cur][1]]) + x + pw + pw - 1 >= nowans + pw && m >= sumb[trie[cur][1]]) {
fl = 1;
solve(trie[cur][0], nowval, x + pw, nowans + pw, std::min(minv, mina[trie[cur][1]]), m - sumb[trie[cur][1]], k - 1);
}
if (!fl) {
solve(trie[cur][0], nowval, x, nowans, minv, m, k - 1);
solve(trie[cur][1], nowval + pw, x + pw, nowans, minv, m, k - 1);
}
}
}
void dickdreamer() {
read(n, m, k);
for (int i = 0; i <= tot; ++i) {
trie[i][0] = trie[i][1] = sumb[i] = mina[i] = 0;
}
tot = 0;
for (int i = 1; i <= n; ++i) read(a[i]);
int64_t tsb = 0;
for (int i = 1; i <= n; ++i) {
read(b[i]);
ins(a[i], b[i]);
tsb += b[i];
}
ans = 0, solve(0, 0, 0, 0, ((i128)1 << k) - 1, m, k - 1);
if (tsb <= m) {
ans = std::max(ans, *std::min_element(a + 1, a + 1 + n) + ((i128)1 << k) - 1);
}
write(ans, '\n');
}
int32_t main() {
#ifdef ORZXKR
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
int T = 1;
read(sid, T);
while (T--) dickdreamer();
// std::cerr << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";
return 0;
}