hdu 6059---Kanade's trio(字典树)
Problem Description
Give you an array A[1..n],you need to calculate how many tuples (i,j,k) satisfy that (i<j<k) and ((A[i] xor A[j])<(A[j] xor A[k]))
There are T test cases.
1≤T≤20
1≤∑n≤5∗105
0≤A[i]<230
There are T test cases.
1≤T≤20
1≤∑n≤5∗105
0≤A[i]<230
Input
There is only one integer T on first line.
For each test case , the first line consists of one integer n ,and the second line consists of n integers which means the array A[1..n]
For each test case , the first line consists of one integer n ,and the second line consists of n integers which means the array A[1..n]
Output
For each test case , output an integer , which means the answer.
Sample Input
1
5
1 2 3 4 5
Sample Output
6
题意:输入一个数列a[1]~a[n] ,求有多少个三元组(i,j,k) 满足1<=i<j<k<=n 且 a[i]异或a[j] < a[j]异或a[k]?
思路:对于a[i]与a[k],对于二进制从高位向低位进行判断,如果30位(A[i]<2^30)到25位相同,那么a[j]的这些位不管值是多少不影响异或后 a[i] 与 a[k] 的大小关系,现在第24位不同,那么a[j]的这一位必须和a[i]相同,这样a[k]异或a[j]的值一定大于a[i]异或a[j] ,第23位到第0位不管a[j]取何值不会影响大小关系了。 有上述可以得出我们只需要判断a[i]和a[k]的二进制最高不相同位就行,那么可以用一个二进制的字典树存储这n个数。
从a[i]~a[n]将a[k]插入字典树中,每次插入时需要记录 当前节点有多少数(num表示)、当前节点对应的a[j]有多少(count表示),用cn[32][2]记录第i位为0和1时的a[j]的个数,所以每次到一个节点时用count+=cn[i][1-t],表示当前的位(0或1),这样可以保证j<k,但是没有保证i<j ;
接下来将cn[][]清空,从a[1]~a[n]的进行删除,对于a[i]删除,可以保证i<k ,那么可以用count-num*cn[i][t] 保证i<j ;
代码如下:
#include <iostream> #include <algorithm> #include <cstdio> #include <cstring> using namespace std; typedef long long LL; const int N=5e5+5; int a[N],p[35],cn[32][2]; struct node { node *son[2]; int count; int num; node() { count=0; num=0; son[0]=son[1]=NULL; } }; node *root; void add(int x,int v) { node * now=root; for(int i=30;i>=0;i--) { int t=(!!(p[i]&x)); if(now->son[t]==NULL) now->son[t]=new node(); now=now->son[t]; now->num+=v; cn[i][t]++; now->count+=v*cn[i][1-t];///当前点对应的j的个数; } } LL cal(int x) { node * now=root; LL sum=0; for(int i=30;i>=0;i--) { int t=(!!(p[i]&x)); node* bro=now->son[1-t]; if(bro) sum+=bro->count - ((LL)bro->num*(LL)cn[i][t]); now=now->son[t]; if(!now) break; } return sum; } int main() { ///cout << "Hello world!" << endl; int T; cin>>T; p[0]=1; for(int i=1;i<32;i++) p[i]=p[i-1]<<1; while(T--) { int n; scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d",&a[i]); root=new node(); memset(cn,0,sizeof(cn)); for(int i=1;i<=n;i++) add(a[i],1); memset(cn,0,sizeof(cn)); LL ans=0; for(int i=1;i<n;i++){ add(a[i],-1); ans+=cal(a[i]); } printf("%lld\n",ans); } return 0; }