cf 888G - Xor-MST(01字典树+分治)

题目链接:传送门

题目思路:

对于ai ,找到一个aj 满足其二进制的公共前缀最长(公共前缀越长,异或值越小),如果aj有多个,那么再枚举判断和谁连边是最优的。

对于本题的做法,可以采用针对第k位的0/1进行分治(把区间按第k位的分成两个子区间),这样能保证每个ai会和另一个公共前缀最长的aj连边。回溯时,是两个连通块(单个点也可视作连通块)进行合并(两个连通块的公共前缀也是最长的),如果直接暴力合并,那么每一层分治的总复杂度是O(n2) 的,显然不行, 这里可以把一个连通块的所有点都放到一棵字典树里(字典树维护二进制),再用另一个连通块的点依次在字典树里查询(从高位到低位,尽量沿着与ai 的某一位相同的 0/1 方向查找),详细可以参考代码。

代码:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long LL;
 4 typedef unsigned long long uLL;
 5 typedef pair<int,int> pii;
 6 typedef pair<LL,LL> pLL;
 7 typedef pair<double,double> pdd;
 8 const int N=2e5+5;
 9 const int M=1e7+5;
10 const int inf=0x3f3f3f3f;
11 const LL mod=1e9+7;
12 const double eps=1e-8;
13 const long double pi=acos(-1.0L);
14 #define ls (i<<1)
15 #define rs (i<<1|1)
16 #define fi first
17 #define se second
18 #define pb push_back
19 #define eb emplace_back
20 #define mk make_pair
21 #define mem(a,b) memset(a,b,sizeof(a))
22 LL read()
23 {
24     LL x=0,t=1;
25     char ch;
26     while(!isdigit(ch=getchar())) if(ch=='-') t=-1;
27     while(isdigit(ch)){ x=10*x+ch-'0'; ch=getchar(); }
28     return x*t;
29 }
30 int bt[N<<5][2],a[N],cnt;
31 LL ans;
32 void ins(int x)
33 {
34     int p=1;
35     for(int i=29;i>=0;i--)
36     {
37         int t=((1<<i)&x)>0;
38         if(!bt[p][t]) bt[p][t]=++cnt,bt[cnt][0]=bt[cnt][1]=0;
39         p=bt[p][t];
40     }
41 }
42 void cal(int l,int mid,int r)
43 {
44     bt[cnt=1][0]=bt[1][1]=0;
45     for(int i=l;i<mid;i++) ins(a[i]);
46     LL res=1e12;
47     for(int i=mid;i<=r;i++)
48     {
49         int p=1;
50         LL sum=0;
51         for(int j=29;j>=0;j--)
52         {
53             int x=(a[i]>>j)&1;
54             if(!bt[p][x]) sum+=1<<j,x^=1;
55             p=bt[p][x];
56         }
57         res=min(res,sum);
58     }
59     ans+=res;
60 }
61 void dfs(int l,int r,int k)
62 {
63     if(k==-1) return ;
64    // printf("%d , %d , %d\n",l,r,k);
65     int mid=0;
66     for(int i=l;i<=r&&!mid;i++)
67         if((1<<k)&a[i]) mid=i;
68     if(mid==0) mid=r+1;
69     if(l<mid) dfs(l,mid-1,k-1);
70     if(mid<=r) dfs(mid,r,k-1);
71     if(l<mid&&mid<=r) cal(l,mid,r);
72 }
73 int main()
74 {
75     int n=read();
76     for(int i=1;i<=n;i++) a[i]=read();
77     sort(a+1,a+n+1);
78     dfs(1,n,29);
79     printf("%lld\n",ans);
80     return 0;
81 }
82 /*
83 4
84 1 2 3 4
85 */
View Code

 

posted @ 2020-11-17 18:56  DeepJay  阅读(123)  评论(0编辑  收藏  举报