BSOJ6325【10.17题目】异或doubt
题目
分析
21.10.20考试T1。
首先看到字典序,容易想到按位贪心,然后看到异或容易想到trie/线性基。
于是考虑一个naive的trie的做法,对A建立trie,然后对于每一个B查询异或最小值,全部放进一个小根堆,每次取出堆顶作为当前的c,并且注意重新插入。
这个东西最坏显然是 \(O(n^2logV)\) 的,暂时还不会卡。
实际测试的话最低可以得80分,实现好了甚至有100分。
接下来考虑正解:
显然一个一个按位贪心的做法非常naive,没办法优化,于是考虑整体来做。
对A和B一起建立一个trie,但是 \(sum\) 值分别保存,记作 \(sum1\) 和 \(sum2\) 。
然后考虑直接 \(dfs\) :
对于当前遍历到的两个节点,一个代表A,一个代表B。
我们尽量合并两棵树的同侧,于是就先合并A的左儿子和B的左儿子,以及A的右儿子和B的右儿子(递归下去合并)。
如果还有剩余,也就是合并不完的话,就合并A的右儿子和B的左儿子,以及A的左儿子和B的右儿子。
最后我们对于求出来的c数组排个序即可。
然后就结束了,具体见代码。
代码
#include<bits/stdc++.h>
using namespace std;
template<typename T>
inline void read(T &x){
x=0;bool f=false;char ch=getchar();
while(!isdigit(ch)){f|=ch=='-';ch=getchar();}
while(isdigit(ch)){x=x*10+(ch^48);ch=getchar();}
x=f?-x:x;
return;
}
const int N=2e5+5,M=2e5+5,MOD=1e9+7,INF=1e9+7;
int B=30;
int n,m,a[N],b[N],rt=1,cur=1,c[N],cnt;
struct trie{
int ch[2],sum[2];
#define ch(x,i) t[x].ch[i]
#define sum(x,i) t[x].sum[i]
}t[N*60];
void Insert(int x,int tp){
int now=rt;sum(rt,tp)++;
for(int i=(1<<B);i;i>>=1){
int v=(x&i);v=v?1:0;
if(!ch(now,v)) ch(now,v)=++cur;
now=ch(now,v);
sum(now,tp)++;
}
return ;
}
int dfs(int u,int v,int dep,int now1,int now2){
//printf("u:%d v:%d dep:%d now1:%d now2:%d sum1:%d sum2:%d\n",u,v,dep,now1,now2,sum(u,0),sum(v,1));
if(!u||!v||sum(u,0)<=0||sum(v,1)<=0) return 0;
if(dep==-1){
int len=min(sum(u,0),sum(v,1));
for(int i=1;i<=len;i++) c[++cnt]=(now1^now2);
sum(u,0)-=len,sum(v,1)-=len;
return len;
}
int tmp=0;
if(ch(u,0)&&ch(v,0)) tmp+=dfs(ch(u,0),ch(v,0),dep-1,now1,now2);
if(ch(u,1)&&ch(v,1)) tmp+=dfs(ch(u,1),ch(v,1),dep-1,now1|(1<<dep),now2|(1<<dep));
if(ch(u,0)&&ch(v,1)) tmp+=dfs(ch(u,0),ch(v,1),dep-1,now1,now2|(1<<dep));
if(ch(u,1)&&ch(v,0)) tmp+=dfs(ch(u,1),ch(v,0),dep-1,now1|(1<<dep),now2);
sum(u,0)-=tmp,sum(v,1)-=tmp;
return min(sum(u,0),sum(v,1))+tmp;
}
signed main(){
// system("fc doubt.out doubt3.ans");
// freopen("doubt.in","r",stdin);
// freopen("doubt.out","w",stdout);
read(n);
for(int i=1;i<=n;i++) read(a[i]),Insert(a[i],0);
for(int i=1;i<=n;i++) read(b[i]),Insert(b[i],1);
dfs(rt,rt,B,0,0);
sort(c+1,c+cnt+1);
for(int i=1;i<=cnt;i++){
if(i!=cnt) printf("%d ",c[i]);
else printf("%d",c[i]);
}
return 0;
}
/*
3
3 2 1
4 5 6
*/