Codeforces 888G Xor-MST (最小异或生成树)

给定 n 个点,每个点有个点权,任意两点可以连一条边,边权为两点权的异或值,求最小生成树

想法一:暴力求出所有边权,然后把边按边权从小到大排序,用kruskal跑最小生成树

想法二:把边排序后,发现最小的边权就是两个相同的值的异或值(为0),其次就是两个只在第 0 位不同的数的异或值 (为1)....

把01trie画出来,每个叶子节点就代表一个数,可以发现最小权值的边就是一个叶子节点和它本身(两个相同的数),其次就是lca距离叶子节点为一的两个叶子节点(这两个叶子节点只有第0位不同),再其次就是lca距离叶子节点距离为2的两个叶子节点...

于是,从最小的边权开始跑kruskal,肯定是在01trie上lca深度最大的两个叶子节点开始连边,在01tie上画图,可以发现对于每颗子树的根节点,一定是其左右子树上的叶子节点先构成了最小生成树后,再从左子树上选一叶子节点与右子树上选一叶子节点连通,于是就可以递归了

想法三:怎么计算连通左右子树的边的最小权值,可以枚举其中一颗子树上的每一个数,用O(log n)找最小异或数对的方法找与另一颗子树的所有数的异或最小值。其查找次数为O(n log n),所以复杂度为O(n logn logn)

#include<iostream>
#include<algorithm>
#include<cstdio>
using namespace std;
const int MAXN = 2e5+7;
const int INF = 1e9+7;
int a[MAXN];
int n;
int cnt = 0;
struct NODE{
	int ptp[2];
	int L,R,size;
}trie[MAXN*30];
void add(int x,int id){
	int now = 0;
	for(int i = 30;i>=0;i--){
		int t = 0;
		if(x & (1<<i)) t = 1;
		int &tt = trie[now].ptp[t];
		if(!tt) tt = ++cnt;
		now = tt;
		if(!trie[now].L) trie[now].L = id;
		trie[now].R = id;
		trie[now].size = trie[now].R - trie[now].L + 1;
	}
}
int cal(int s,int pos,int x){//s为根节点,s在pos+1位,找和x取0到pos位时的异或最小值 
	int now = s;
	int res = 0;
	for(int i = pos;i>=0;i--){
		int t = 0;
		if(x & (1<<i)) t = 1;
		if(trie[now].ptp[t]){
			now = trie[now].ptp[t];
		}
		else {
			now = trie[now].ptp[t^1];
			res |= 1<<i;
		}
	}
	return res;
}
long long solve(int s,int pos){//s在pos位 
	int x = trie[s].ptp[0],y = trie[s].ptp[1];
	if(x && y){//左右子树都存在 
		int res = INF;
		if(trie[x].size < trie[y].size){
			for(int i = trie[x].L;i <= trie[x].R;i++){//左子树小就枚举左子树 
				res = min(res,cal(y,pos-2,a[i]) + (1<<pos-1));//x,y在pos-1位 
			}
		}
		else{
			for(int i = trie[y].L;i <= trie[y].R;i++){//右子树小就枚举右子树 
				res = min(res,cal(x,pos-2,a[i]) + (1<<pos-1));
			}
		}
		return (long long)res + solve(x,pos-1) + solve(y,pos-1);
	}
	else if(x){
		//只有左子树,递归 
		return solve(x,pos-1);
	}
	else if(y){
		return solve(y,pos-1);
	}
	return 0;
}
int main()
{
	cin>>n;
	for(int i = 1;i <= n;i++) scanf("%d",&a[i]);
	sort(a+1,a+n+1);
	for(int i = 1;i <= n;i++) add(a[i],i);
	long long ans = solve(0,31);
	printf("%lld\n",ans);
	return 0;
} 

  

 

posted @ 2021-10-26 23:10  beta_dust  阅读(135)  评论(0编辑  收藏  举报