hdu6059( Trie )
hdu6059
题意
给定数组 \(A\) ,问有多少对下标 \((i, j, k)\) 满足 \(i < j < k\) 且 \((A[i] \ xor \ A[j]) < (A[j] \ xor \ A[k])\) 。
分析
首先建一棵字典树,从高到低位插入所有数字(长度要相同,所以前面不足用 \(0\) 补),在插入的过程中计算对于每个 \(k\) 前面有多少个 \(j\) 可以配对(也就是在前面插入的值中寻找),只要将当前位取反就能找到有多少个 \(j\) 与之对应(可以用一个 \(cnt\) 数组记录每一位分别为 \(0\) 和 \(1\) 的次数)。
查询时,从头开始删除,每删除一次(一是要去标记一下,而是去掉这个数作为 \(k\) 的影响),再去查询对应的数,我们求的实际是对于每个 \(A[i]\) 有几个 \(A[j] \ A[k]\) 与之对应 。在去计算答案的时候前面的标记就有作用了,已经标记作为 \(j\) 的与后面的 \(k\) 产生的配对要减掉,因为要满足 \(i < j\) ,前面标记过的 \(j < i\) ,所以前面的 \(j\) 与 \(k\) 的配对是无效的。
大致的意思就是对 \(i\) 去寻找 \(k\) ,然后删掉不满足条件的 \(j\) 。
建议结合代码画图理解一下。
code
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
const int MAXN = 2e6 + 10;
int n;
int a[MAXN];
int root, L;
int nxt[MAXN][2], cnt[MAXN][2], has[MAXN];
ll sum[MAXN];
ll ans;
int newnode() {
nxt[L][0] = nxt[L][1] = 0;
return L++;
}
void init() {
L = 1;
root = newnode();
memset(sum, 0, sizeof sum);
memset(cnt, 0, sizeof cnt);
memset(has, 0, sizeof has);
}
void insert(int x, int k) {
int tp[32], c = 0;
memset(tp, 0, sizeof tp);
while(x) {
tp[c++] = x % 2;
x >>= 1;
}
int now = root;
for(int i = 30; i >= 0; i--) {
int d = tp[i];
if(!nxt[now][d]) nxt[now][d] = newnode();
now = nxt[now][d];
cnt[i][d]++;
sum[now] += k * cnt[i][d ^ 1];
has[now] += k;
}
}
void query(int x) {
int tp[32], c = 0;
memset(tp, 0, sizeof tp);
while(x) {
tp[c++] = x % 2;
x >>= 1;
}
int now = root;
for(int i = 30; i >= 0; i--) {
int d = tp[i];
int tmp = nxt[now][d ^ 1];
if(tmp) {
ans += sum[tmp] - 1LL * has[tmp] * cnt[i][d];
}
now = nxt[now][d];
if(!now) break;
}
}
int main() {
int T;
scanf("%d", &T);
while(T--) {
init();
scanf("%d", &n);
for(int i = 0; i < n; i++) {
scanf("%d", &a[i]);
insert(a[i], 1);
}
ans = 0;
memset(cnt, 0, sizeof cnt);
for(int i = 0; i < n; i++) {
insert(a[i], -1);
query(a[i]);
}
printf("%lld\n", ans);
}
return 0;
}