hdu6059( Trie )

hdu6059

题意

给定数组 \(A\) ,问有多少对下标 \((i, j, k)\) 满足 \(i < j < k\)\((A[i] \ xor \ A[j]) < (A[j] \ xor \ A[k])\)

分析

首先建一棵字典树,从高到低位插入所有数字(长度要相同,所以前面不足用 \(0\) 补),在插入的过程中计算对于每个 \(k\) 前面有多少个 \(j\) 可以配对(也就是在前面插入的值中寻找),只要将当前位取反就能找到有多少个 \(j\) 与之对应(可以用一个 \(cnt\) 数组记录每一位分别为 \(0\)\(1\) 的次数)。

查询时,从头开始删除,每删除一次(一是要去标记一下,而是去掉这个数作为 \(k\) 的影响),再去查询对应的数,我们求的实际是对于每个 \(A[i]\) 有几个 \(A[j] \ A[k]\) 与之对应 。在去计算答案的时候前面的标记就有作用了,已经标记作为 \(j\) 的与后面的 \(k\) 产生的配对要减掉,因为要满足 \(i < j\) ,前面标记过的 \(j < i\) ,所以前面的 \(j\)\(k\) 的配对是无效的。

大致的意思就是对 \(i\) 去寻找 \(k\) ,然后删掉不满足条件的 \(j\)

建议结合代码画图理解一下。

code

#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
const int MAXN = 2e6 + 10;
int n;
int a[MAXN];
int root, L;
int nxt[MAXN][2], cnt[MAXN][2], has[MAXN];
ll sum[MAXN];
ll ans;
int newnode() {
    nxt[L][0] = nxt[L][1] = 0;
    return L++;
}
void init() {
    L = 1;
    root = newnode();
    memset(sum, 0, sizeof sum);
    memset(cnt, 0, sizeof cnt);
    memset(has, 0, sizeof has);
}
void insert(int x, int k) {
    int tp[32], c = 0;
    memset(tp, 0, sizeof tp);
    while(x) {
        tp[c++] = x % 2;
        x >>= 1;
    }
    int now = root;
    for(int i = 30; i >= 0; i--) {
        int d = tp[i];
        if(!nxt[now][d]) nxt[now][d] = newnode();
        now = nxt[now][d];
        cnt[i][d]++;
        sum[now] += k * cnt[i][d ^ 1];
        has[now] += k;
    }
}
void query(int x) {
    int tp[32], c = 0;
    memset(tp, 0, sizeof tp);
    while(x) {
        tp[c++] = x % 2;
        x >>= 1;
    }
    int now = root;
    for(int i = 30; i >= 0; i--) {
        int d = tp[i];
        int tmp = nxt[now][d ^ 1];
        if(tmp) {
            ans += sum[tmp] - 1LL * has[tmp] * cnt[i][d];
        }
        now = nxt[now][d];
        if(!now) break;
    }
}
int main() {
    int T;
    scanf("%d", &T);
    while(T--) {
        init();
        scanf("%d", &n);  
        for(int i = 0; i < n; i++) {
            scanf("%d", &a[i]);
            insert(a[i], 1);
        }
        ans = 0;
        memset(cnt, 0, sizeof cnt);
        for(int i = 0; i < n; i++) {
            insert(a[i], -1);
            query(a[i]);
        }
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2017-08-02 22:28  ftae  阅读(109)  评论(0编辑  收藏  举报