【luogu CF241B】Friends(Trie树)(二分)

Friends

题目链接:luogu CF241B

题目大意

给你一个序列,然后要你把序列里数两两异或得到的值从小到大排序,要你求前 k 大的值的和。

思路

首先考虑找到第 \(k\) 大是多大,因为异或嘛,自然想到 Trie 树。
于是就在 Trie 数上每个数放上去找位置,然后全部一起走可以做到 \(n\log n\),如果外面单独二分则是 \(n\log^2n\) 都可以。

这里补一句因为题目要的是 \(i<j\),这样不太好搞,不过 \(i=j\) 的时候是 \(0\),所以我们可以试着取消这个限制,发现就只会重复一倍的共贡献(当然你选的数量也要翻倍)

然后考虑怎么统计,也是考虑每一位逐个统计,为了方便我们把数组排序,这样连续的一段数在 Trie 数上最后投射的位置也是连续,于是可以每个子树一个 \(l,r\) 表示里面的有的点的区间。
然后就类似于在树上再走,然后把大于的部分都统计了。

统计大于的部分就弄 \(f_{i,j,k}\) 为第 \(i\) 位,前 \(k\) 个数中有多少个数这一位是 \(j\)
然后配合上 \(l,r\) 一个前缀和即可。
这部分复杂度是 \(n\log ^2n\) 的。

然后最后的等于直接单独处理即可。

代码

#include<cstdio>
#include<cstring> 
#include<iostream>
#include<algorithm>
#define ll long long
#define mo 1000000007 

using namespace std;

const int N = 5e4 + 100;
int n, a[N], tmp[N];
ll k, f[32][2][N];

struct Trie {
	struct node {
		int son[2], sz, l, r;
	}t[N * 32];
	int tot, knum;
	
	void insert(int x, int pla) {
		int now = 1;
		for (int i = 31; i >= 0; i--) {
			int to = (x >> i) & 1;
			if (!t[now].son[to]) t[now].son[to] = ++tot, t[tot].l = pla;
			now = t[now].son[to]; t[now].r = max(t[now].r, pla); t[now].sz++;
		}
	}
	
	int find(ll k) {
		int ans = 0; for (int i = 1; i <= n; i++) tmp[i] = 1;
		for (int i = 31; i >= 0; i--) {
			ll num = 0; int go = 0;
			for (int j = 1; j <= n; j++) num += t[t[tmp[j]].son[((a[j] >> i) & 1) ^ 1]].sz;
			if (num >= k) go = 1, ans |= (1 << i);
				else go = 0, k -= num;
			for (int j = 1; j <= n; j++) tmp[j] = t[tmp[j]].son[((a[j] >> i) & 1) ^ go];
		}
		knum = k;
		return ans;
	}
	
	ll clac_ans(ll k) {
		ll val = find(k); ll ans = knum * val % mo;
		int now = 0; for (int i = 1; i <= n; i++) tmp[i] = 1;
		for (int i = 31; i >= 0; i--) {
			int go = (val >> i) & 1;
			if (go == 0) {
				for (int j = 1; j <= n; j++) {
					int l = t[t[tmp[j]].son[((a[j] >> i) & 1) ^ 1]].l;
					int r = t[t[tmp[j]].son[((a[j] >> i) & 1) ^ 1]].r;
					for (int d = 31; d >= 0; d--)
						(ans += (1ll << d) * (f[d][((a[j] >> d) & 1) ^ 1][r] - f[d][((a[j] >> d) & 1) ^ 1][l - 1]) % mo) %= mo;
				}
			}
			for (int j = 1; j <= n; j++) tmp[j] = t[tmp[j]].son[((a[j] >> i) & 1) ^ go]; 
		}
		return ans;
	}
}T;

int main() {
	scanf("%d %lld", &n, &k); k <<= 1;//x-y y-x
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i]);
	}
	sort(a + 1, a + n + 1);
	
	T.tot = 1; T.t[1].l = 1; T.t[1].r = n; 
	for (int i = 1; i <= n; i++)
		T.insert(a[i], i);
	for (int i = 0; i <= 31; i++)
		for (int j = 0; j <= 1; j++)
			for (int k = 1; k <= n; k++)
				f[i][j][k] = f[i][j][k - 1] + (((a[k] >> i) & 1) == j);
	
	printf("%lld", T.clac_ans(k) * (mo + 1) / 2 % mo);
	
	return 0;
}
posted @ 2022-09-21 18:56  あおいSakura  阅读(15)  评论(0编辑  收藏  举报