BZOJ3689 - 异或之(trie)
题意
给定n个非负整数A[1], A[2], ……, A[n]。
对于每对(i, j)满足1 <= i < j <= n,得到一个新的数A[i] xor A[j],这样共有n*(n-1)/2个新的数。求这些数(不包含A[i])中前k小的数。
思路
一看到异或,就想到可能要用trie树来处理。
层数越深,两个数异或的结果越小。
所以可以从最底层往上处理,对trie树每一层的每个结点左子树和右子树包含的数暴力一一异或,一直到得到超过k个数。由于每上一层,增加的计算最多为2倍,所以复杂度是线性的。
得到的这些数就一定是排前的。最后排序输出k个数即可。
#include <bits/stdc++.h>
#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define FILE freopen(".//data_generator//in.txt","r",stdin),freopen("res.txt","w",stdout)
#define FI freopen(".//data_generator//in.txt","r",stdin)
#define FO freopen("res.txt","w",stdout)
#define pb push_back
#define mp make_pair
#define seteps(N) fixed << setprecision(N)
typedef long long ll;
using namespace std;
/*-----------------------------------------------------------------*/
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 2e6 + 10;
const double eps = 1e-5;
int tr[N][2];
int l[N], r[N];
vector<int> points[50];
vector<int> arr;
vector<int> ans;
int head[N];
int cnt[N];
int k;
int ct;
void insert(int x) {
int cur = 0;
for(int i = 30; i >= 0; i--) {
bool ntp = (x & (1 << i));
if(!tr[cur][ntp]) {
tr[cur][ntp] = ++ct;
cur = ct;
tr[cur][0] = tr[cur][1] = 0;
} else {
cur = tr[cur][ntp];
}
}
cnt[cur]++;
}
void dfs(int cur, int dep, int val) {
points[dep + 1].push_back(cur);
if(dep < 0) {
l[cur] = arr.size();
for(int i = 0 ;i < cnt[cur]; i++) {
arr.push_back(val);
}
r[cur] = arr.size() - 1;
return ;
}
if(tr[cur][0]) {
dfs(tr[cur][0], dep - 1, val);
l[cur] = l[tr[cur][0]];
r[cur] = r[tr[cur][0]];
}
if(tr[cur][1]) {
dfs(tr[cur][1], dep - 1, val + (1 << dep));
if(!tr[cur][0]) {
l[cur] = l[tr[cur][1]];
}
r[cur] = r[tr[cur][1]];
}
}
int main() {
//FILE;
IOS;
int n;
cin >> n >> k;
for(int i = 1; i <= n; i++) {
int v;
cin >> v;
insert(v);
}
dfs(0, 30, 0);
int tot = 0;
for(int p : points[0]) {
int c = cnt[p];
tot += c * (c - 1) / 2;
}
if(tot >= k) {
for(int i = 0; i < k; i++) {
cout << 0 << endl;
}
} else {
for(int i = 0; i < tot; i++) ans.push_back(0);
for(int i = 1; i <= 31; i++) {
for(int p : points[i]) {
int ls = tr[p][0], rs = tr[p][1];
if(ls && rs) {
for(int i = l[ls]; i <= r[ls]; i++) {
for(int j = l[rs]; j <= r[rs]; j++) {
ans.push_back(arr[i] ^ arr[j]);
}
}
}
}
if(ans.size() > k) break;
}
sort(ans.begin(), ans.end());
for(int i = 0; i < k; i++) {
cout << ans[i] << " ";
}
}
}