CF888G Xor-MST
这是作者对于这道典型例题思路的记录,用于作者自己本人对于字典树和异或的应用的加深印象,同时也欢迎大家的阅读。
这道题据说是有合理的最小生成树算法可以过(好像是\(Boruvka\)),但是没有试过,感觉是不可以过的。
这一道题的\(n<=2e5\),用普通的\(Prim\)和\(Kruskal\)绝对是会炸的,无论是时间还是空间,所以我们可以尝试从异或这一运算入手。
因为异或这一运算是相同的为\(0\),不同的为\(1\),所以我们可以尝试找出每个数字的最长前缀(算上前导\(0\)),说不定可以用后缀数组还是后缀自动机求\(lcp\)???
但是这里可以使用字典树,即建一棵有\(n\)个叶子节点的树使得每一条由根节点到叶子节点的链都代表着一个数字。然后我们再利用这棵优秀的二叉树(因为只有\(0\)和\(1\))来进行查询操作。
void add(int x)
{
int s[35];
for(int i=30;i>=1;--i)
{
s[i]=x&1;
x>>=1;
}
int tmp=0;
for(int i=1;i<=30;++i)
{
if(!tr[tmp].son[s[i]])
tr[tmp].son[s[i]]=++size;
tmp=tr[tmp].son[s[i]];
}
return ;
}
以上是建树部分的代码。就是建一棵很普通的字典树,记得补上前导\(0\)。
查询操作则尽量走相同的方向,因为我们是从高位走向低位的,所以我们要在尽量高的位置保持异或值为\(0\),实在不行时就走不同的路,再更新答案。
ll find(int u,int v,int dep)
{
if(!tr[u].son[0]&&!tr[u].son[1]&&!tr[v].son[0]&&!tr[v].son[1])
return 0;
ll res=2e9+7;
if(tr[u].son[0]&&tr[v].son[0])
res=min(res,find(tr[u].son[0],tr[v].son[0],dep-1));
if(tr[u].son[1]&&tr[v].son[1])
res=min(res,find(tr[u].son[1],tr[v].son[1],dep-1));
if(res<2e9+7)
return res;
if(tr[u].son[0]&&tr[v].son[1])
res=min(res,(1<<(dep-1))+find(tr[u].son[0],tr[v].son[1],dep-1));
if(tr[u].son[1]&&tr[v].son[0])
res=min(res,(1<<(dep-1))+find(tr[u].son[1],tr[v].son[0],dep-1));
return res;
}
ll ans=0;
void dfs(int u,int dep)
{
if(tr[u].son[0]&&tr[u].son[1])
ans+=(1<<(dep-1))+find(tr[u].son[0],tr[u].son[1],dep-1);
if(tr[u].son[0])
dfs(tr[u].son[0],dep-1);
if(tr[u].son[1])
dfs(tr[u].son[1],dep-1);
}
以上部分为查询操作,其中这棵二叉树只有\(n\)个叶子节点,所以在没有数字重复的情况下,肯定只有\(n-1\)个节点的儿子数为\(2\),所以直接\(dfs\)即可。
以下为完整代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e5+5;
int n;
int a[N],loc[N];
struct Node
{
int son[2];
}tr[N*30];
int size=0;
void add(int x)
{
int s[35];
for(int i=30;i>=1;--i)
{
s[i]=x&1;
x>>=1;
}
int tmp=0;
for(int i=1;i<=30;++i)
{
if(!tr[tmp].son[s[i]])
tr[tmp].son[s[i]]=++size;
tmp=tr[tmp].son[s[i]];
}
return ;
}
ll find(int u,int v,int dep)
{
if(!tr[u].son[0]&&!tr[u].son[1]&&!tr[v].son[0]&&!tr[v].son[1])
return 0;
ll res=2e9+7;
if(tr[u].son[0]&&tr[v].son[0])
res=min(res,find(tr[u].son[0],tr[v].son[0],dep-1));
if(tr[u].son[1]&&tr[v].son[1])
res=min(res,find(tr[u].son[1],tr[v].son[1],dep-1));
if(res<2e9+7)
return res;
if(tr[u].son[0]&&tr[v].son[1])
res=min(res,(1<<(dep-1))+find(tr[u].son[0],tr[v].son[1],dep-1));
if(tr[u].son[1]&&tr[v].son[0])
res=min(res,(1<<(dep-1))+find(tr[u].son[1],tr[v].son[0],dep-1));
return res;
}
ll ans=0;
void dfs(int u,int dep)
{
if(tr[u].son[0]&&tr[u].son[1])
ans+=(1<<(dep-1))+find(tr[u].son[0],tr[u].son[1],dep-1);
if(tr[u].son[0])
dfs(tr[u].son[0],dep-1);
if(tr[u].son[1])
dfs(tr[u].son[1],dep-1);
}
int main()
{
cin>>n;
for(int i=1;i<=n;++i)
scanf("%d",&a[i]);
sort(a+1,a+1+n);
n=unique(a+1,a+1+n)-a-1;
for(int i=1;i<=n;++i)
add(a[i]);
dfs(0,30);
printf("%lld\n",ans);
return 0;
}