hdu6059 Kanade's trio 字典树+容斥

转自:http://blog.csdn.net/dormousenone/article/details/76570172
/**
题目:hdu6059 Kanade's trio
链接:http://acm.hdu.edu.cn/showproblem.php?pid=6059
题意:含 N 个数字的 A 数组,求有多少个三元组 (i,j,k) 满足 i<j<k 且 (Ai⊕Aj)<(Aj⊕Ak)
思路:
利用字典树维护前 k-1 个数。当前处理第 k 个数。

显然对于 k 与 i 的最高不相同位 kp 与 ip :

当 ip=0 , kp=1 时,该最高不相同位之前的 ihigher=khigher 。则 jhigher 可以为任意数,

均不对 i, k 更高位(指最高不相同位之前的高位,后同)的比较产生影响。而此时 jp 位必须为 0 才可保证不等式 (Ai⊕Aj)<(Aj⊕Ak) 成立。

当 ip=1,kp=0 时,jp 位必须为 1 ,更高位任意。

故利用数组 cnt[31][2] 统计每一位为 0 ,为 1 的有多少个(在前 K-1 个数中)。

在字典树插入第 k 个数时,同时统计最高不相同位,即对于每次插入的 p 位为 num[p] (取值 0 或 1),

在同父节点对应的 1-num[p] 为根子树的所有节点均可作为 i 来寻找 j 以获取对答案的贡献。

其中又仅要求 jp 与 ip (ip 值即 1-num[p]) 相同,故 jp 有 cnt[p][ 1-num[p] ] 种取值方案。

但是,同时需要注意 i 与 j 有在 A 数组的先后关系 (i<j) 需要保证。

故在字典树中额外维护一个 Ext 点,记录将每次新加入的点与多少原有点可构成 i, j 关系。在后续计算贡献时去掉。
*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<map>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
using namespace std;
typedef pair<int,int> P;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const int maxnode = 5e5*30+10;
const int maxn = 5e5+100;
const int sigma_size = 2;
int cnt[maxn][30];
int ch[maxnode][sigma_size];
int sz;
int idx(char c){return c-'a';}
LL ans, ext;
struct node
{
    int cnt;
    int ext;
}val[maxnode];
int s[30];
void insert()
{
    int u = 0;
    for(int i = 29; i >= 0; i--){
        int c = s[i];
        if(ch[u][!c]){
            ans += (LL)val[ch[u][!c]].cnt*(val[ch[u][!c]].cnt-1)/2;
            ext += (LL)val[ch[u][!c]].cnt*(cnt[i][!c]-val[ch[u][!c]].cnt)-val[ch[u][!c]].ext;
        }
        if(!ch[u][c]){
            memset(ch[sz], 0, sizeof ch[sz]);
            val[sz].cnt = 0;
            val[sz].ext = 0;
            ch[u][c] = sz++;
        }
        u = ch[u][c];
        val[u].cnt++;
        val[u].ext += cnt[i][c]-val[u].cnt;
    }
}
int main()
{
    //freopen("C:\\Users\\accqx\\Desktop\\in.txt","r",stdin);
    int T, n;
    cin>>T;
    while(T--)
    {
        scanf("%d",&n);
        int x;
        sz = 1;
        memset(ch[0], 0, sizeof ch[0]);
        memset(cnt, 0, sizeof cnt);
        ans = ext = 0;
        for(int i = 1; i <= n; i++){
            scanf("%d",&x);
            for(int j = 0; j < 30; j++){
                cnt[j][x%2]++;
                s[j] = x%2;
                x/=2;
            }
            insert();
        }
        printf("%lld\n",ans+ext);
    }

    return 0;
}

 

posted on 2017-08-04 20:09  hnust_accqx  阅读(157)  评论(0编辑  收藏  举报

导航