LOJ#3048. 「十二省联考 2019」异或粽子 Trie

原文链接www.cnblogs.com/zhouzhendong/p/LOJ3048.html

题解

  $O(n\log^2 {a_i})$ 的做法比较简单:

  1. 求出第 k 大的是什么: 二分答案,在Trie树上统计一下答案。

  2. 求出前 k 大的和:已经知道了第 k 大的数值,那么,只要再在 Trie 树上走一趟就好了。

  这两部分直接暴力都是 $O(n\log^2 a_i)$ 的。

  那么我们来稍微优化一下:

  对于 1. ,我们改成 Trie 树上二分,变成了 $O(n\log a_i)$ 的。

  对于 2. ,我们发现:这个 Trie 最多有 O(n) 个分叉点。而且在 2 操作中,搜索子树的次数也是 O(n) 的。那么,如果我们可以预处理出每一个节点代表的子树中,每一种二进制位的出现次数,就可以解决问题。处理出全部的信息显然是不可能的,我们考虑只处理每一个分叉的两个儿子,由于它是 O(n) 的,所以时间复杂度降到 $O(n \log a_i)$ 。

  接下来的代码是 $O(n\log ^2 a_i)$ 的代码。

  至于我开的数据范围以及 __int128 为何使用,我给出一个解释:某神仙把数据范围改成了 n<=7.5e5, k<=n(n-1)/2 。然而我写两个log卡常数没有卡过去

代码

#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof (x))
#define For(i,a,b) for (int i=a;i<=b;i++)
#define Fod(i,b,a) for (int i=b;i>=a;i--)
#define pb(x) push_back(x)
#define mp(x,y) make_pair(x,y)
#define fi first
#define se second
#define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I')
#define outval(x) printf(#x" = %d\n",x)
#define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("")
#define outtag(x) puts("----------"#x"----------")
#define outarr(a,L,R) printf(#a"[%d...%d] = ",L,R);\
						For(_v2,L,R)printf("%d ",a[_v2]);puts("");
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef vector <int> vi;
typedef unsigned uint;
namespace IO{
    const int Len=1<<20;
    char Ibuf[Len+1],*Is=Ibuf,*It=Ibuf;
    char gc(){
        if (Is==It){
            It=(Is=Ibuf)+fread(Ibuf,1,Len,stdin);
            if (Is==It)
                return EOF;
        }
        return *Is++;
    }
}
#define getchar IO::gc
LL read(){
	LL x=0,f=0;
	char ch=getchar();
	while (!isdigit(ch))
    	f|=ch=='-',ch=getchar();
	while (isdigit(ch))
		x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return f?-x:x;
}
const int N=750005;
int n,bt=31;
LL k;
LL a[N],s[N];
int size[N*33],Next[N*33][2];
uint val[N*33];
int cnt=1;
void Insert(LL v){
	int x=1,t;
	size[x]++;
	Fod(i,bt,0){
		t=v>>i&1;
		if (!Next[x][t])
			Next[x][t]=++cnt,val[cnt]=val[x]|((LL)t)<<i;
		x=Next[x][t];
		size[x]++;
	}
}
LL calc(int a,int b){
	return a==b?((LL)size[a]*(size[a]-1))>>1:(LL)size[a]*size[b];
}
__int128 ans=0;
int *c0,*c1;
void get(int x,int d){
	if (Next[x][0]){
		c0[d]+=size[Next[x][0]];
		get(Next[x][0],d-1);
	}
	if (Next[x][1]){
		c1[d]+=size[Next[x][1]];
		get(Next[x][1],d-1);
	}
}
void calc2(int x,int y,int d){
	static int bitx[2][33],bity[2][33];
	clr(bitx),clr(bity);
	c0=bitx[0],c1=bitx[1],get(x,d);
	if (x!=y){
		c0=bity[0],c1=bity[1],get(y,d);
		For(i,0,d)
			ans+=(__int128)(bitx[0][i]*bity[1][i]+bitx[1][i]*bity[0][i])<<i;
	}
	else {
		For(i,0,d)
			ans+=(__int128)(bitx[0][i]*bitx[1][i])<<i;
	}
	ans+=(__int128)calc(x,y)*(val[x]^val[y]);
}
LL now=0,ub,res_cnt=0,tmp=0;
void dfs(int x,int y,int d){
	uint v=val[x]^val[y];
	if (v>=ub)
		return (void)(tmp+=calc(x,y));
	if ((1LL<<(d+1))+v<=ub)
		return;
	if (Next[x][0]){
		if (Next[y][0])
			dfs(Next[x][0],Next[y][0],d-1);
		if (Next[y][1])
			dfs(Next[x][0],Next[y][1],d-1);
	}
	if (Next[x][1]){
		if (Next[y][1])
			dfs(Next[x][1],Next[y][1],d-1);
		if (Next[y][0]&&x!=y)
			dfs(Next[x][1],Next[y][0],d-1);
	}
}
void dfs2(int x,int y,int d){
	uint v=val[x]^val[y];
	if (v>=ub)
		return calc2(x,y,d);
	if ((1LL<<(d+1))+v<=ub)
		return;
	if (Next[x][0]){
		if (Next[y][0])
			dfs2(Next[x][0],Next[y][0],d-1);
		if (Next[y][1])
			dfs2(Next[x][0],Next[y][1],d-1);
	}
	if (Next[x][1]){
		if (Next[y][1])
			dfs2(Next[x][1],Next[y][1],d-1);
		if (Next[y][0]&&x!=y)
			dfs2(Next[x][1],Next[y][0],d-1);
	}
}
void write(__int128 x){
	if (x>9)
		write(x/10);
	putchar('0'+x%10);
}
int main(){
	n=read(),k=read();
	For(i,1,n)
		a[i]=read();
	s[0]=0;
	For(i,1,n)
		s[i]=s[i-1]^a[i];
	s[++n]=0;
	For(i,1,n)
		Insert(s[i]);
	Fod(i,bt,0){
		ub=now|1LL<<i;
		tmp=0,dfs(1,1,bt);
		if (tmp>=k)
			now=ub,res_cnt=tmp;
	}
	ub=now;
	dfs2(1,1,bt);
	ans-=(__int128)(res_cnt-k)*ub;
	write(ans),putchar('\n');
	return 0;
}

  

posted @ 2019-04-09 20:07  zzd233  阅读(311)  评论(0编辑  收藏  举报