【luogu CF241B】Friends(Trie树)(二分)
Friends
题目链接:luogu CF241B
题目大意
给你一个序列,然后要你把序列里数两两异或得到的值从小到大排序,要你求前 k 大的值的和。
思路
首先考虑找到第 \(k\) 大是多大,因为异或嘛,自然想到 Trie 树。
于是就在 Trie 数上每个数放上去找位置,然后全部一起走可以做到 \(n\log n\),如果外面单独二分则是 \(n\log^2n\) 都可以。
这里补一句因为题目要的是 \(i<j\),这样不太好搞,不过 \(i=j\) 的时候是 \(0\),所以我们可以试着取消这个限制,发现就只会重复一倍的共贡献(当然你选的数量也要翻倍)
然后考虑怎么统计,也是考虑每一位逐个统计,为了方便我们把数组排序,这样连续的一段数在 Trie 数上最后投射的位置也是连续,于是可以每个子树一个 \(l,r\) 表示里面的有的点的区间。
然后就类似于在树上再走,然后把大于的部分都统计了。
统计大于的部分就弄 \(f_{i,j,k}\) 为第 \(i\) 位,前 \(k\) 个数中有多少个数这一位是 \(j\)。
然后配合上 \(l,r\) 一个前缀和即可。
这部分复杂度是 \(n\log ^2n\) 的。
然后最后的等于直接单独处理即可。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
#define mo 1000000007
using namespace std;
const int N = 5e4 + 100;
int n, a[N], tmp[N];
ll k, f[32][2][N];
struct Trie {
struct node {
int son[2], sz, l, r;
}t[N * 32];
int tot, knum;
void insert(int x, int pla) {
int now = 1;
for (int i = 31; i >= 0; i--) {
int to = (x >> i) & 1;
if (!t[now].son[to]) t[now].son[to] = ++tot, t[tot].l = pla;
now = t[now].son[to]; t[now].r = max(t[now].r, pla); t[now].sz++;
}
}
int find(ll k) {
int ans = 0; for (int i = 1; i <= n; i++) tmp[i] = 1;
for (int i = 31; i >= 0; i--) {
ll num = 0; int go = 0;
for (int j = 1; j <= n; j++) num += t[t[tmp[j]].son[((a[j] >> i) & 1) ^ 1]].sz;
if (num >= k) go = 1, ans |= (1 << i);
else go = 0, k -= num;
for (int j = 1; j <= n; j++) tmp[j] = t[tmp[j]].son[((a[j] >> i) & 1) ^ go];
}
knum = k;
return ans;
}
ll clac_ans(ll k) {
ll val = find(k); ll ans = knum * val % mo;
int now = 0; for (int i = 1; i <= n; i++) tmp[i] = 1;
for (int i = 31; i >= 0; i--) {
int go = (val >> i) & 1;
if (go == 0) {
for (int j = 1; j <= n; j++) {
int l = t[t[tmp[j]].son[((a[j] >> i) & 1) ^ 1]].l;
int r = t[t[tmp[j]].son[((a[j] >> i) & 1) ^ 1]].r;
for (int d = 31; d >= 0; d--)
(ans += (1ll << d) * (f[d][((a[j] >> d) & 1) ^ 1][r] - f[d][((a[j] >> d) & 1) ^ 1][l - 1]) % mo) %= mo;
}
}
for (int j = 1; j <= n; j++) tmp[j] = t[tmp[j]].son[((a[j] >> i) & 1) ^ go];
}
return ans;
}
}T;
int main() {
scanf("%d %lld", &n, &k); k <<= 1;//x-y y-x
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
sort(a + 1, a + n + 1);
T.tot = 1; T.t[1].l = 1; T.t[1].r = n;
for (int i = 1; i <= n; i++)
T.insert(a[i], i);
for (int i = 0; i <= 31; i++)
for (int j = 0; j <= 1; j++)
for (int k = 1; k <= n; k++)
f[i][j][k] = f[i][j][k - 1] + (((a[k] >> i) & 1) == j);
printf("%lld", T.clac_ans(k) * (mo + 1) / 2 % mo);
return 0;
}