CF241B Friends
知识点:01 Trie
「十二省联考 2019」异或粽子 是本题的另一数据范围版本。
简述
给定一长度为 \(n\) 的数列 \(a\),求两两异或值前 \(k\) 大的和,答案对 \(10^9 + 7\) 取模。
\(1\le n\le 5\times 10^4\),\(0\le k\le \frac{n(n-1)}{2}\),\(0\le a_i\le 10^9\)。
6S,256MB。
分析
两两异或和是有序对,考虑令 \(k\) 翻倍,再使最终答案除 2,以除去有序对的限制。显然这样可以保证原问题中前 \(k\) 大元素都会被统计到两次,且根据异或的自反性,不合法的满足 \(l = r\) 数对贡献为 0,不会影响答案。
要求前 \(k\) 大异或值的和,考虑先求出第 \(k\) 的的异或值。显然可以二分解决。先对 \(a\) 建立 Trie 并在 Trie 上维护 \(\operatorname{size}\),之后二分答案,再枚举所有 \(a_i\),求得有多少个异或值 \(a_i\oplus a_j\) 不小于二分量,检查二分量是否是第 \(k\) 大即可。此过程可以在 Trie 上检查二分量各二进制与 \(a_i\) 各二进制的情况来解决,代码如下所示:
namespace Trie {
LL QueryCnt(int now_, int lth_, LL val_, LL lim_) {
if (!now_ && lth_ != 30) return 0; //无贡献的情况
if (lth_ < 0) return siz[now_]; //深入到底,此时有 val[now_] = lim_
int ch = 1ll * val_ >> lth_ & 1ll;
if (lim_ >> lth_ & 1) { //lim_ 这一位是 1,若异或值不小于大于 lim_,则异或值此位必须为 1。
return QueryCnt(tr[now_][ch ^ 1], lth_ - 1, val_, lim_);
} else { //lim_ 这一位是 0,这一位是 1 的异或值一定比 lim_ 大,统计贡献 size,并统计异或值此位为 0 的贡献。
return siz[tr[now_][ch ^ 1]] +
QueryCnt(tr[now_][ch], lth_ - 1, val_, lim_);
}
}
}
bool Check(LL mid_) {
LL ret = 0;
//QueryCnt 能返回所有异或值 a_i xor ? 不小于 mid 的数的个数
for (int i = 1; i <= n; ++ i) ret += Trie::QueryCnt(0, 30, a[i], mid_);
return k <= ret;
}
之后考虑求得所有不小于第 \(k\) 大值的异或值的和。仍考虑先枚举再枚举所有 \(a_i\),在 Trie 上检查二分量各二进制与 \(a_i\) 各二进制的情况,得到所有异或值。
观察上面代码可知,当第 \(k\) 大值的某一位是 0 时,说明这一位是 1 的状态为根的子树中所有 \(a_j\) 与 \(a_i\) 的异或值大于第 \(k\) 大值,它们对答案有贡献。问题变为如何快速统计 Trie 的某棵子树中所有数与给定值的异或值的和。
这里有一个小 trick,先对 \(a\) 进行排序,再按照顺序把它们插入 Trie 中,这能使得 Trie 一棵子树中的数是原数列中一段连续的子段。其原理显然,可以类比基数排序。
数列与给定值的异或值的和很好做,考虑拆位,对每个数进行二进制分解,用前缀和分别维护每一位上 0/1 的个数。查询时取出两区间中各位 0/1 的个数,考察与每位不同的数的对数,统计贡献即可。
考虑复杂度,第一部分二分答案 + Trie 上贪心复杂度为 \(O(n\log^2 w)\),第二部分统计贡献 Trie 上贪心 + 拆位统计贡献复杂度也为 \(O(n\log^2 w)\)。其中 \(w = \max\{a_i\} \approx 30\)。
代码
//知识点:01 Trie,贪心
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 5e4 + 10;
const LL mod = 1e9 + 7;
const LL inv2 = 500000004;
//=============================================================
LL n, k, kth, ans, a[kN], sum[kN][32];
//=============================================================
inline LL read() {
LL f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
LL Calc(int l_, int r_, int val_) {
if (!l_ && !r_) return 0;
LL ret = 0;
for (int i = 30; i >= 0; -- i) {
LL cnt[2] = {0};
cnt[1] = sum[r_][i] - sum[l_ - 1][i], cnt[0] = r_ - l_ + 1 - cnt[1];
ret = (ret + (1ll << i) % mod * cnt[(val_ >> i & 1) ^ 1] % mod) % mod;
}
return ret;
}
namespace Trie {
const int kMaxNode = kN << 7;
int node_num, tr[kMaxNode][2];
int l[kMaxNode], r[kMaxNode];
LL siz[kMaxNode];
void Insert(int val_, int id_) {
int now_ = 0;
for (int i = 30; i >= 0; -- i) {
int ch = 1ll * val_ >> i & 1ll;
if (!tr[now_][ch]) {
tr[now_][ch] = ++ node_num;
l[node_num] = id_;
}
now_ = tr[now_][ch];
r[now_] = id_;
siz[now_] ++;
}
}
LL QueryCnt(int now_, int lth_, LL val_, LL lim_) {
if (!now_ && lth_ != 30) return 0;
if (lth_ < 0) return siz[now_];
int ch = 1ll * val_ >> lth_ & 1ll;
if (lim_ >> lth_ & 1) {
return QueryCnt(tr[now_][ch ^ 1], lth_ - 1, val_, lim_);
} else {
return siz[tr[now_][ch ^ 1]] +
QueryCnt(tr[now_][ch], lth_ - 1, val_, lim_);
}
}
LL QuerySum(int now_, int lth_, LL val_, LL lim_) {
if (!now_ && lth_ != 30) return 0;
if (lth_ < 0) return Calc(l[now_], r[now_], val_);
int ch = 1ll * val_ >> lth_ & 1ll;
if (lim_ >> lth_ & 1) {
return QuerySum(tr[now_][ch ^ 1], lth_ - 1, val_, lim_);
} else {
return (Calc(l[tr[now_][ch ^ 1]], r[tr[now_][ch ^ 1]], val_) +
QuerySum(tr[now_][ch], lth_ - 1, val_, lim_)) % mod;
}
}
}
void Init() {
n = read(), k = read() << 1ll;
for (int i = 1; i <= n; ++ i) a[i] = read();
std::sort(a + 1, a + n + 1);
for (int i = 1; i <= n; ++ i) {
Trie::Insert(a[i], i);
for (int j = 30; j >= 0; -- j) {
sum[i][j] = sum[i - 1][j] + (a[i] >> j & 1);
}
}
}
bool Check(LL mid_) {
LL ret = 0;
//返回大于 mid 的数的个数
for (int i = 1; i <= n; ++ i) ret += Trie::QueryCnt(0, 30, a[i], mid_);
return k <= ret;
}
//=============================================================
int main() {
Init();
// Check(1);
for (LL l = 0, r = 2e9; l <= r; ) {
int mid = (l + r) >> 1;
if (Check(mid)) {
kth = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
for (int i = 1; i <= n; ++ i) {
ans = (ans + Trie::QuerySum(0, 30, a[i], kth + 1)) % mod;
k -= Trie::QueryCnt(0, 30, a[i], kth + 1);
}
ans = 1ll * (ans + 1ll * k * kth % mod) % mod * inv2 % mod;
printf("%lld\n", ans);
return 0;
}