hdu 6059 Kanade's trio(trie+容斥)

题目链接:hdu 6059 Kanade's trio

题意:

给你n个数,让你找有多少个(i,j,k),使得i<j<k满足a[i]^a[j]<a[j]^a[k]。

题解:

首先考虑a[i]和a[k],将他们都转换成二进制,对于a[i]和a[k],我们用Bi[p]表示二进制下的a[i]的第p位。考虑a[i]和a[k]二进制不同的最高位,这里假设为p,如果Bi[p]=0,Bk[p]=1,那么Bj[p]要为0,才能使得a[i]^a[j]<a[j]^a[k]。(因为p前面的位相同,只有亦或后的第p位k为1,i为0就行了)比如a[i]=5,a[k]=6,那么二进制下不同的最高为就是2,那么a[j]可以为0,1,5,13等,只要Bj[2]=0就行了。同理如果Bi[p]=1,Bk[p]=0,那么Bj[p]要为1。

知道这个性质后,现在我们就可以对每一个数进行枚举了。这里我们从前往后枚举每一个a[k],将1~(k-1)的数都全部插到trie树里面,并且开一个数组num[30][2]记录第p位为0和为1的数有多少个。

然后在插入过程中,在trie树里面找有多少个i和j,j的个数就是num[p][Bk[p]^1]],i的个数就是当前cnt[x][Bk[p]^1]](因为这样计算保证了a[i]和a[k]的二进制前缀相同,就该位不同),但是cnt[x][Bk[p]^1]]中又包括有j的个数,所以这里计算要注意一下,具体看代码。

这里有一部分i,j没有保证i<j,因为a[j]可能是在a[i]前面的数,所以在对于每一个新插入的数,都要保存一下这个数被当成a[i]时,有多少个a[j]在他前面,就是num[p][Bk[p]]]-cnt[x][Bk[p]],然后后面的数在计算时就要减掉。(具体的话举几个数模拟一下就知道了)

 1 #include<bits/stdc++.h>
 2 #define F(i,a,b) for(int i=a;i<=b;++i)
 3 using namespace std;
 4 
 5 const int M=5e5+7;
 6 int t,n,a[M],num[31][2],s[40];
 7 long long ans;
 8 
 9 struct Trie
10 {
11     static const int N=5e6+7,tyn=2;
12     int tr[N][tyn],cnt[N],tot;long long ext[N];
13     void nw(){cnt[++tot]=0,ext[tot]=0,memset(tr[tot],0,sizeof(tr[tot]));}
14     void init(){tot=-1,nw();}
15     void insert(int *s,int x=0){
16         for(int i=0,w;i<30;i++)
17         {
18             if(!tr[x][w=s[i]])nw(),tr[x][w]=tot;
19             if(tr[x][w^1])
20             {
21                 int nxt=tr[x][w^1];
22                 ans+=1ll*cnt[nxt]*(cnt[nxt]-1)>>1;
23                 ans+=1ll*(num[i][w^1]-cnt[nxt])*cnt[nxt]-ext[nxt];
24             }
25             x=tr[x][w],cnt[x]++,ext[x]+=num[i][w]-cnt[x];
26         }
27     }
28 }trie;
29 
30 int main()
31 {
32     scanf("%d",&t);
33     while(t--)
34     {
35         scanf("%d",&n);
36         F(i,1,n)scanf("%d",a+i);
37         trie.init(),ans=0;
38         F(i,1,n)
39         {
40             for(int j=29;j>=0;a[i]>>=1,j--)
41             {
42                 s[j]=a[i]&1;
43                 num[j][a[i]&1]++;
44             }
45             trie.insert(s);
46         }
47         printf("%lld\n",ans);
48     }
49     return 0;
50 }
View Code

 

posted @ 2017-08-02 17:14  bin_gege  阅读(173)  评论(0编辑  收藏  举报