[十二省联考 2019]异或粽子[题解]
[十二省联考 2019] 异或粽子
\(Problem\)
小粽是一个喜欢吃粽子的好孩子。今天她在家里自己做起了粽子。
小粽面前有 \(n\) 种互不相同的粽子馅儿,小粽将它们摆放为了一排,并从左至右编号为 \(1\) 到 \(n\)。第 \(i\) 种馅儿具有一个非负整数的属性值 \(a_i\)。每种馅儿的数量都足够多,即小粽不会因为缺少原料而做不出想要的粽子。小粽准备用这些馅儿来做出 \(k\) 个粽子。
小粽的做法是:选两个整数数 \(l\), \(r\),满足 \(1 \leqslant l \leqslant r \leqslant n\),将编号在 \([l, r]\) 范围内的所有馅儿混合做成一个粽子,所得的粽子的美味度为这些粽子的属性值的异或和。(异或就是我们常说的 xor 运算,即 C/C++ 中的 ˆ
运算符或 Pascal 中的 xor
运算符)
小粽想品尝不同口味的粽子,因此它不希望用同样的馅儿的集合做出一个以上的
粽子。
小粽希望她做出的所有粽子的美味度之和最大。请你帮她求出这个值吧!
数据范围
测试点 | \(n\) | \(k\) |
---|---|---|
\(1\), \(2\), \(3\), \(4\), \(5\), \(6\), \(7\), \(8\) | \(\leqslant 10^3\) | \(\leqslant 10^3\) |
\(9\), \(10\), \(11\), \(12\) | \(\leqslant 5 \times 10^5\) | \(\leqslant 10^3\) |
\(13\), \(14\), \(15\), \(16\) | \(\leqslant 10^3\) | \(\leqslant 2 \times 10^5\) |
\(17\), \(18\), \(19\), \(20\) | \(\leqslant 5 \times 10^5\) | \(\leqslant 2 \times 10^5\) |
\(Sol\)
\(O(n^2)\) 的做法是显然的,注意到异或的性质,\(a\) \(xor\) \(b\) \(xor\) \(a\) \(=\) \(b\)。又因为这题要求取到的区间是连续的,所以可以考虑对原数组做一个前缀异或和,不难想到所有答案都是从这个前缀异或和数组中任意取两个数异或起来得到的。需要注意的是需要在数组前端加入一个 \(0\),以此考虑到从 \(1\) 取到某个 \(r\) 的情况。
注意到 \(a_i\leq 4294967295\),事实上即为弄成二进制的 \(0\) 至 \(31\) 位,从高位到低位依次考虑每一位计算答案。
首先,我们想要拿到异或前 \(k\) 大的区间,则我们就尽量先选择 \(1\)。考虑这种情况下我们应该怎么选择,首先,最开始我们答案匹配的区间为 \(0\) 至 \(n\)(\(0\) 即为最开始补的 \(0\))。此时我们可以任意选择两个数,但是我们想让选择的这两个数异或得到的第 \(i\) 位为 \(1\),则显然和第 \(i\) 位为 \(0\) 与第 \(i\) 位为 \(1\) 的两种数有关。第 \(i\) 位为 \(1\) 的答案的个数即为第 \(i\) 位为 \(1\) 的数的个数乘上第 \(i\) 位为 \(0\) 的数的个数。但此时我们面临两个问题:如何快速找到这两个数量?由于我们确定了该位为 \(1\) 必须要求这两类数相乘,所以之后我们必须要求答案必须要从每一类数中选择一个。
比较幸运的是,我们发现,当我们将待选择的 \(n + 1\) 个数按大小排序过后,每一位的 \(0\) 与 \(1\) 的数在之前的每一位确定过后是连在一起的,所以我们可以排序过后二分找到那个临界点。同时,如果之前的操作要求我们必须从数集 \(A\) 以及数集 \(B\) 种各选择一个时,若我们选择 \(1\),则按照 \(A\) 数集与 \(B\) 数集取出的数哪一个第 \(i\) 位二进制为 \(1\) 有两种情况,且每种情况分裂过后又是一个新的数集 \(A\) 与 \(B\)。
则我们进行模拟。
每次考虑先在第 \(i\) 位填 \(1\),以及填 \(1\) 过后的数集以及对应的数集(同时可能不存在对应的数集,在当前数集中任意选择),然后进行讨论。不存在对应的数集当且仅当将在不存在对应数集的数集中选 \(0\) 时才会存在,即将 \(0\) 和 \(1\) 分开,在 \(0\) 和 \(1\) 范围内单独选择。其他情况都能够讨论,再把所有的数集存到下一轮即可。
注意剩余的待选择的数,不要多算。
时间复杂度比较玄学,考虑到 \(n\) 的范围,每一轮最多只会有 \(n\) 个区间,所以卡满是 \(O(2^{32}n)\),但这个好像不太容易卡满,而且剩的还挺多的,所以跑得比较快。
\(code\)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 5e5 + 10;
inline int read()
{
int s = 0, w = 1;
char ch = getchar();
while(ch < '0' || ch > '9') { if(ch == '-') w *= -1; ch = getchar(); }
while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
return s * w;
}
struct node{ int l, r; }A[32][N], B[32][N];
int n, k, ans, cnt[32], top[32];
int a[N], to[32][N], pre[N], up[N];
inline int find(int x, int y, int bit) //查二进制的第 bit 位是 1 的分界线
{
int l = x, r = y, res = y + 1;
while(l <= r){
int mid = (l + r) >> 1;
if((pre[mid] >> bit) & 1) res = mid, r = mid - 1;
else l = mid + 1;
}
return res;
}
//bit 表示当前 dp 到第 i 位
inline void Sol(int bit, int v, int lv, int nex)
{
// cout << "Begin: " << bit << "\n";
//尝试选择 1
int inc = 0;
if(nex >= 0)top[nex] = cnt[nex] = 0;
for(register int i = 1; i <= top[bit]; i++){ //遍历 A 找对应 B
int lx = A[bit][i].l, rx = A[bit][i].r;
if(to[bit][i] == -1){ //不存在对应区间
int res = find(lx, rx, bit);
// cout << res << "\n";
if(rx - res + 1 && res - lx){
if(nex >= 0) A[nex][++top[nex]].l = lx, A[nex][top[nex]].r = res - 1, to[nex][top[nex]] = ++cnt[nex]; //分裂区间
if(nex >= 0) B[nex][cnt[nex]].l = res, B[nex][cnt[nex]].r = rx;
inc += (rx - res + 1) * (res - lx); //任意选择的情况的方案数
}
}
else{
int ly = B[bit][to[bit][i]].l, ry = B[bit][to[bit][i]].r;
int resx = find(lx, rx, bit), resy = find(ly, ry, bit);
int nlx0 = lx, nrx0 = resx - 1, nlx1 = resx, nrx1 = rx;
int nly0 = ly, nry0 = resy - 1, nly1 = resy, nry1 = ry;
if(nrx0 >= nlx0 && nry1 >= nly1){ //A0 B1
inc += (nrx0 - nlx0 + 1) * (nry1 - nly1 + 1); //算方案
if(nex >= 0) A[nex][++top[nex]].l = nlx0, A[nex][top[nex]].r = nrx0, to[nex][top[nex]] = ++cnt[nex];
if(nex >= 0) B[nex][cnt[nex]].l = nly1, B[nex][cnt[nex]].r = nry1;
}
if(nrx1 >= nlx1 && nry0 >= nly0){//A1 B0
inc += (nrx1 - nlx1 + 1) * (nry0 - nly0 + 1);
if(nex >= 0) A[nex][++top[nex]].l = nlx1, A[nex][top[nex]].r = nrx1, to[nex][top[nex]] = ++cnt[nex];
if(nex >= 0) B[nex][cnt[nex]].l = nly0, B[nex][cnt[nex]].r = nry0;
}
}
}
if(nex < 0) { ans = ans + (v | 1ll) * min(lv, inc); /*cout << ans << "\n";*/ } //已经到了最后一位,方案数 \times v 即可
if(inc && nex >= 0) Sol(bit - 1, v | (1ll << bit), lv, nex - 1); //向下一层
int tem = inc;
if(lv - inc <= 0) return;
if(nex >= 0) top[nex] = cnt[nex] = 0; //注意清空,即这一轮选 0 时下一轮需要满足的区间信息
inc = 0;
for(register int i = 1; i <= top[bit]; i++){
int lx = A[bit][i].l, rx = A[bit][i].r;
if(to[bit][i] == -1){ //没有对应区间
int res = find(lx, rx, bit);
if(nex >= 0 && lx <= res - 1) A[nex][++top[nex]].l = lx, A[nex][top[nex]].r = res - 1, to[nex][top[nex]] = -1; //更新得到的区间都不存在对应区间,因为只需要在自己内部选择即可
if(nex >= 0 && rx >= res) A[nex][++top[nex]].l = res, A[nex][top[nex]].r = rx, to[nex][top[nex]] = -1;
inc += (res - lx) * (res - lx - 1) / 2 + (rx - res + 1) * (rx - res) / 2;
}
else{ //同理讨论
int ly = B[bit][to[bit][i]].l, ry = B[bit][to[bit][i]].r;
int resx = find(lx, rx, bit), resy = find(ly, ry, bit);
int nlx0 = lx, nrx0 = resx - 1, nlx1 = resx, nrx1 = rx;
int nly0 = ly, nry0 = resy - 1, nly1 = resy, nry1 = ry;
if(nrx0 >= nlx0 && nry0 >= nly0){
inc += (nrx0 - nlx0 + 1) * (nry0 - nly0 + 1);
if(nex >= 0) A[nex][++top[nex]].l = nlx0, A[nex][top[nex]].r = nrx0, to[nex][top[nex]] = ++cnt[nex];
if(nex >= 0) B[nex][cnt[nex]].l = nly0, B[nex][cnt[nex]].r = nry0;
}
if(nrx1 >= nlx1 && nry1 >= nly1){
inc += (nrx1 - nlx1 + 1) * (nry1 - nly1 + 1);
if(nex >= 0) A[nex][++top[nex]].l = nlx1, A[nex][top[nex]].r = nrx1, to[nex][top[nex]] = ++cnt[nex];
if(nex >= 0) B[nex][cnt[nex]].l = nly1, B[nex][cnt[nex]].r = nry1;
}
}
}
if(nex < 0) { ans = ans + v * min(lv - tem, inc); } //统计答案
if(inc && nex >= 0) Sol(bit - 1, v, lv - tem, nex - 1);
}
signed main()
{
//freopen("data.in", "r", stdin);
//二进制到了 31 位,即 (1 << 32) - 1
n = read(), k = read();
for(register int i = 1; i <= n; i++) a[i] = read();
for(register int i = 1; i <= n; i++) pre[i] = pre[i - 1] ^ a[i];
sort(pre + 1, pre + n + 1); //排序
// for(register int i = 0; i <= n; i++) cout << pre[i] << "\n";
//此题显然是在所有的方案中选择权值从大到小靠前的 k 种,考虑数位 dp
A[31][++top[31]].l = 0, A[31][top[31]].r = n, to[31][top[31]] = -1;
Sol(31, 0, k, 30);
printf("%lld\n", ans);
return 0;
}