CF241B Friends 题解
是 异或粽子 的超级加强版,但是本题因为 \(m\) 很大,不能套用那一题的解法。
换一种思路:考虑把 \(a_i\) 从高位到低位插入 0-1 Trie 之后,二分第 \(m\) 大,通过第 \(m\) 大求答案。
对于二分的一个值 \(x\),枚举每个位置 \(i\),在 0-1 Trie 上找与 \(a_i\) 异或值大于等于 \(x\) 的值的个数。
类比求最大异或和的过程,考虑搜索到第 \(j\) 位。如果 \(x\) 的第 \(j\) 位为 \(1\),为了最终异或值大于等于 \(x\),可能的数一定在与 \(a_i\) 的第 \(j\) 位相异的子树中,递归即可;反之,如果 \(x\) 的第 \(j\) 位为 \(0\),与 \(a_i\) 的第 \(j\) 位相异的子树中的值一定全部满足条件,递归与 \(a_i\) 的第 \(j\) 位相同的子树即可。
于是可以在 \(O(n \log^2 w)\) 的时间内找到第 \(m\) 大的异或值(\(w\) 为值域,后文同),设这个值为 \(k\)。
下一步是求前 \(m\) 大两两异或值的和。
容易想到,类似前文所述,只需要处理被计算的完整的子树与 \(a_i\) 的异或值的和。(即搜索时找到的 \(k\) 的第 \(j\) 位为 \(0\),与 \(a_i\) 的第 \(j\) 位相异的子树)直接对这些子树的根节点打标记,整体遍历一次 0-1 Trie 时容易得到这棵子树内每一位上 \(0\) 和 \(1\) 的数量,答案也就容易统计了。
至多有 \(n \log w\) 个标记,处理每个标记需要枚举 \(\log w\) 位。同时,至多合并 \(O(n \log w)\) 次,单次合并的时间为 \(O(\log w)\)。综上,时间复杂度 \(O(n \log^2 w)\)。
将两部分拼起来就得到了最终做法,时间复杂度 \(O(n \log^2 w)\),可以通过。
#include <iostream>
#include <map>
#include <vector>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;
static inline ll qpow(ll a, ll b) {
ll ret = 1;
while (b) {
if (b & 1)
ret = ret * a % mod;
a = a * a % mod;
b >>= 1;
}
return ret;
}
const ll _2 = qpow(2, mod - 2);
int n, m;
int a[50005];
int rt;
map<int, vector<int>> mp;
struct trie {
int son[2];
int val, cnt[32];
} tr[1600005];
int cnt;
static inline int insert(int x, int dep, int p) {
if (!p)
p = ++cnt;
if (dep == -1) {
++tr[p].val;
return p;
}
int tag = (x >> dep) & 1;
tr[p].son[tag] = insert(x, dep - 1, tr[p].son[tag]);
tr[p].val = tr[tr[p].son[0]].val + tr[tr[p].son[1]].val;
return p;
}
static inline ll query(int x, int aim, int dep, int p) {
if (!p)
return 0;
if (dep == -1)
return tr[p].val;
int tg = (x >> dep) & 1;
int tg2 = (aim >> dep) & 1;
if (tg2)
return query(x, aim, dep - 1, tr[p].son[tg ^ 1]);
return tr[tr[p].son[tg ^ 1]].val + query(x, aim, dep - 1, tr[p].son[tg]);
}
ll ans, sum;
static inline void query2(int x, int aim, int dep, int p) {
if (!p)
return;
if (dep == -1)
return;
int tg = (x >> dep) & 1;
int tg2 = (aim >> dep) & 1;
if (tg2)
return query2(x, aim, dep - 1, tr[p].son[tg ^ 1]);
if (tr[p].son[tg ^ 1])
mp[tr[p].son[tg ^ 1]].push_back(x);
query2(x, aim, dep - 1, tr[p].son[tg]);
}
static inline void dfs(int dep, int p, int from, int val) {
if (!p)
return;
if (dep == -1) {
if (from)
tr[p].cnt[dep + 1] += tr[p].val;
if (mp.count(p))
for (auto x : mp[p]) {
sum += tr[p].val;
ll tot = 0;
for (int i = 0; i <= dep + 1; ++i) {
int tg = (x >> i) & 1;
if (tg)
tot += (ll)(tr[p].val - tr[p].cnt[i]) * (1ll << i) % mod;
else
tot += (ll)tr[p].cnt[i] * (1ll << i) % mod;
}
for (int i = dep + 2; i <= 30; ++i) {
int tg = (x >> i) & 1;
int tg2 = (val >> i) & 1;
if (tg ^ tg2)
tot += (ll)tr[p].val * (1ll << i) % mod;
}
ans = (ans + tot) % mod;
}
return;
}
if (tr[p].son[0]) {
dfs(dep - 1, tr[p].son[0], 0, val);
for (int i = 0; i <= 30; ++i)
tr[p].cnt[i] += tr[tr[p].son[0]].cnt[i];
}
if (tr[p].son[1]) {
dfs(dep - 1, tr[p].son[1], 1, val | (1 << dep));
for (int i = 0; i <= 30; ++i)
tr[p].cnt[i] += tr[tr[p].son[1]].cnt[i];
}
if (from)
tr[p].cnt[dep + 1] += tr[p].val;
if (mp.count(p))
for (auto x : mp[p]) {
sum += tr[p].val;
ll tot = 0;
for (int i = 0; i <= dep + 1; ++i) {
int tg = (x >> i) & 1;
if (tg)
tot += (ll)(tr[p].val - tr[p].cnt[i]) * (1ll << i) % mod;
else
tot += (ll)tr[p].cnt[i] * (1ll << i) % mod;
}
for (int i = dep + 2; i <= 30; ++i) {
int tg = (x >> i) & 1;
int tg2 = (val >> i) & 1;
if (tg ^ tg2)
tot += (ll)tr[p].val * (1ll << i) % mod;
}
ans = (ans + tot) % mod;
}
}
static inline bool check(int x) {
ll tot = 0;
for (int i = 1; i <= n; ++i)
tot += query(a[i], x, 30, rt);
tot >>= 1;
return tot >= m;
}
static inline void solve() {
cin >> n >> m;
if (m == 0) {
cout << 0 << endl;
return;
}
for (int i = 1; i <= n; ++i) {
cin >> a[i];
rt = insert(a[i], 30, rt);
}
int l = 1;
int r = 1ll << 30;
int ret = -1;
while (l <= r) {
int mid = (int)((l + r) >> 1);
if (check(mid)) {
ret = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
for (int i = 1; i <= n; ++i)
query2(a[i], ret, 30, rt);
dfs(30, rt, 0, 0);
sum >>= 1;
ans = ans * _2 % mod;
ans = (ans + (m - sum) * ret % mod) % mod;
cout << ans << endl;
}
signed main() {
#ifndef ONLINE_JUDGE
freopen("1.in", "r", stdin);
#endif
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
solve();
return 0;
}