LOJ#3048. 「十二省联考 2019」异或粽子 Trie
原文链接www.cnblogs.com/zhouzhendong/p/LOJ3048.html
题解
$O(n\log^2 {a_i})$ 的做法比较简单:
1. 求出第 k 大的是什么: 二分答案,在Trie树上统计一下答案。
2. 求出前 k 大的和:已经知道了第 k 大的数值,那么,只要再在 Trie 树上走一趟就好了。
这两部分直接暴力都是 $O(n\log^2 a_i)$ 的。
那么我们来稍微优化一下:
对于 1. ,我们改成 Trie 树上二分,变成了 $O(n\log a_i)$ 的。
对于 2. ,我们发现:这个 Trie 最多有 O(n) 个分叉点。而且在 2 操作中,搜索子树的次数也是 O(n) 的。那么,如果我们可以预处理出每一个节点代表的子树中,每一种二进制位的出现次数,就可以解决问题。处理出全部的信息显然是不可能的,我们考虑只处理每一个分叉的两个儿子,由于它是 O(n) 的,所以时间复杂度降到 $O(n \log a_i)$ 。
接下来的代码是 $O(n\log ^2 a_i)$ 的代码。
至于我开的数据范围以及 __int128 为何使用,我给出一个解释:某神仙把数据范围改成了 n<=7.5e5, k<=n(n-1)/2 。然而我写两个log卡常数没有卡过去。
代码
#include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof (x)) #define For(i,a,b) for (int i=a;i<=b;i++) #define Fod(i,b,a) for (int i=b;i>=a;i--) #define pb(x) push_back(x) #define mp(x,y) make_pair(x,y) #define fi first #define se second #define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I') #define outval(x) printf(#x" = %d\n",x) #define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("") #define outtag(x) puts("----------"#x"----------") #define outarr(a,L,R) printf(#a"[%d...%d] = ",L,R);\ For(_v2,L,R)printf("%d ",a[_v2]);puts(""); using namespace std; typedef long long LL; typedef unsigned long long ULL; typedef vector <int> vi; typedef unsigned uint; namespace IO{ const int Len=1<<20; char Ibuf[Len+1],*Is=Ibuf,*It=Ibuf; char gc(){ if (Is==It){ It=(Is=Ibuf)+fread(Ibuf,1,Len,stdin); if (Is==It) return EOF; } return *Is++; } } #define getchar IO::gc LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=750005; int n,bt=31; LL k; LL a[N],s[N]; int size[N*33],Next[N*33][2]; uint val[N*33]; int cnt=1; void Insert(LL v){ int x=1,t; size[x]++; Fod(i,bt,0){ t=v>>i&1; if (!Next[x][t]) Next[x][t]=++cnt,val[cnt]=val[x]|((LL)t)<<i; x=Next[x][t]; size[x]++; } } LL calc(int a,int b){ return a==b?((LL)size[a]*(size[a]-1))>>1:(LL)size[a]*size[b]; } __int128 ans=0; int *c0,*c1; void get(int x,int d){ if (Next[x][0]){ c0[d]+=size[Next[x][0]]; get(Next[x][0],d-1); } if (Next[x][1]){ c1[d]+=size[Next[x][1]]; get(Next[x][1],d-1); } } void calc2(int x,int y,int d){ static int bitx[2][33],bity[2][33]; clr(bitx),clr(bity); c0=bitx[0],c1=bitx[1],get(x,d); if (x!=y){ c0=bity[0],c1=bity[1],get(y,d); For(i,0,d) ans+=(__int128)(bitx[0][i]*bity[1][i]+bitx[1][i]*bity[0][i])<<i; } else { For(i,0,d) ans+=(__int128)(bitx[0][i]*bitx[1][i])<<i; } ans+=(__int128)calc(x,y)*(val[x]^val[y]); } LL now=0,ub,res_cnt=0,tmp=0; void dfs(int x,int y,int d){ uint v=val[x]^val[y]; if (v>=ub) return (void)(tmp+=calc(x,y)); if ((1LL<<(d+1))+v<=ub) return; if (Next[x][0]){ if (Next[y][0]) dfs(Next[x][0],Next[y][0],d-1); if (Next[y][1]) dfs(Next[x][0],Next[y][1],d-1); } if (Next[x][1]){ if (Next[y][1]) dfs(Next[x][1],Next[y][1],d-1); if (Next[y][0]&&x!=y) dfs(Next[x][1],Next[y][0],d-1); } } void dfs2(int x,int y,int d){ uint v=val[x]^val[y]; if (v>=ub) return calc2(x,y,d); if ((1LL<<(d+1))+v<=ub) return; if (Next[x][0]){ if (Next[y][0]) dfs2(Next[x][0],Next[y][0],d-1); if (Next[y][1]) dfs2(Next[x][0],Next[y][1],d-1); } if (Next[x][1]){ if (Next[y][1]) dfs2(Next[x][1],Next[y][1],d-1); if (Next[y][0]&&x!=y) dfs2(Next[x][1],Next[y][0],d-1); } } void write(__int128 x){ if (x>9) write(x/10); putchar('0'+x%10); } int main(){ n=read(),k=read(); For(i,1,n) a[i]=read(); s[0]=0; For(i,1,n) s[i]=s[i-1]^a[i]; s[++n]=0; For(i,1,n) Insert(s[i]); Fod(i,bt,0){ ub=now|1LL<<i; tmp=0,dfs(1,1,bt); if (tmp>=k) now=ub,res_cnt=tmp; } ub=now; dfs2(1,1,bt); ans-=(__int128)(res_cnt-k)*ub; write(ans),putchar('\n'); return 0; }