[CSP-S模拟测试]:english(可持久化Trie+启发式合并)

题目传送门(内部题24)


输入格式

第一行有$3$个整数$n,opt$,$opt$的意义将在输出格式中提到。
第二行有$n$个整数,第$i$个整数表示$a_i$。


输出格式

若$opt=1$,输出一行一个整数表示${ans}_1$。
若$opt=2$,输出一行一个整数表示${ans}_2$。
若$opt=3$,输出两行,第一行一个整数${ans}_1$,第二行一个整数${ans}_2$。


样例

样例输入:

3 3
6 1 3

样例输出:

78
6


数据范围与提示

对于所有数据,$1\leqslant n\leqslant {10}^5,1\leqslant opt\leqslant 3,0\leqslant a_i\leqslant {10}^6$。


题解

对于每个数$a_x$,用单调栈找出它作为最大值的区间$[l_x,r_x]$,所有区间只有包含和不相交关系,没有相交关系,而且所以区间构成了一棵二叉树。
对每个区间$[l_x,r_x]$维护一棵$01trie$树$T_x$。
对每个区间$[l_x,r_x]$维护一个数组$f_x$,其中$f_{x,j}$表示该区间中第$j$位为$1$的数有多少个。
所以区间构成了一棵二叉树,可以对区间进行启发式合并,对于$a_x$控制的区间$[l_x,r_x]$,找到它的左右儿子$lch:[l_x,x−1]$和$rch:[x+1,r_x]$,我们只需要考虑所有包含$x$的区间的答案,而且这些区间的最大值都是$a_x$。
若左区间的长度$<$右区间的长度,我们可以枚举左区间中的每个数$a_i$。
对于${ans}_1$,我们可以分别统计每一个二进制位的答案,若$a_i$的第$j$位是$0$,那么第$j$位的贡献就是$2^jf{rch,j}$,若$a_i$的第$j$位是$1$,情况类似。同时,将 $a_i$更新到$f_x$中。
对于${ans}_2$,问题就转化成右区间中有多少个数$v$满足$v\ xor\ a_i>a_x$,可以在$T_{rch}$中查询。同时,将$a_i$插入到$trie$树$T_x$中。
时间复杂度:$\Theta(n\log n\log v)$。
期望得分:$100$分。
实际得分:$100$分。


代码时刻

#include<bits/stdc++.h>
using namespace std;
int n,opt;
int a[100001],c[30],s[100001][30],sta[100001],sum[100001],l[100001],r[100001];
long long flag[30],d[30];
int rt[100001];
int trie[50000000][2],w[50000000],cnt;
long long ans1,ans2;
void add(int x,int y)
{
	sum[y]++;
	for(int i=0;i<=21;i++)
	{
		s[y][i]+=x&1;
		x>>=1;
	}
}
void insert(int x,int l,int r)
{
	for(int i=21;i>=0;i--)
	{
		int p=(x>>i)&1;
		w[l]=w[r]+1;
		trie[l][p^1]=trie[r][p^1];
		trie[l][p]=++cnt;
		l=trie[l][p];
		r=trie[r][p];
	}
	w[l]=w[r]+1;
}
int ask(int x,int y,int l,int r)
{
	int res=0,ans=0;
	for(int i=21;i>=0;i--)
		if((y>>i)&1)
			if(res+flag[i]>x)
			{
				ans+=w[trie[r][0]]-w[trie[l][0]];
				l=trie[l][1];
				r=trie[r][1];
			}
			else
			{
				l=trie[l][0];
				r=trie[r][0];
				res+=flag[i];
			}
		else
			if(res+flag[i]>x)
			{
				ans+=w[trie[r][1]]-w[trie[l][1]];
				l=trie[l][0];
				r=trie[r][0];
			}
			else
			{
				l=trie[l][1];
				r=trie[r][1];
				res+=flag[i];
			}
	return ans;
}
int main()
{
	scanf("%d%d",&n,&opt);
	for(int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	flag[0]=1;for(int i=1;i<=21;i++)flag[i]=flag[i-1]<<1;
	cnt=n;
	for(int i=1;i<=n;i++)
	{
		rt[i]=i;
		add(a[i],i);
		insert(a[i],rt[i],rt[i-1]);
		while(sta[0]&&a[sta[sta[0]]]<=a[i])
			r[sta[sta[0]--]]=i-1;
		l[i]=sta[sta[0]]+1;
		sta[++sta[0]]=i;
		for(int j=0;j<=21;j++)
			s[i][j]+=s[i-1][j];
		sum[i]+=sum[i-1];
	}
	while(sta[0])r[sta[sta[0]--]]=n;
	for(int i=1;i<=n;i++)
	{
		long long res1=0,res2=0;
		if(i-l[i]<=r[i]-i)
		{
			for(int j=0;j<=21;j++)
				c[j]=s[r[i]][j]-s[i-1][j];
			for(int j=l[i];j<=i;j++)
			{
				for(int k=0;k<=21;k++)
				{
					if((a[j]>>k)&1)d[k]=sum[r[i]]-sum[i-1]-c[k];
					else d[k]=c[k];
					res1=(res1+d[k]*flag[k])%1000000007;
				}
				res2=(res2+ask(a[i],a[j],rt[i-1],rt[r[i]]))%1000000007;
			}
		}
		else
		{
			for(int j=0;j<=21;j++)
				c[j]=s[i][j]-s[l[i]-1][j];
			for(int j=i;j<=r[i];j++)
			{
				for(int k=0;k<=21;k++)
				{
					if((a[j]>>k)&1)d[k]=sum[i]-sum[l[i]-1]-c[k];
					else d[k]=c[k];
					res1=(res1+d[k]*flag[k])%1000000007;
				}
				res2=(res2+ask(a[i],a[j],rt[l[i]-1],rt[i]))%1000000007;
			}
		}
		ans1=(ans1+res1*a[i])%1000000007;
		ans2=(ans2+res2*a[i])%1000000007;
	}
	switch(opt)
	{
		case 1:printf("%lld",ans1);break;
		case 2:printf("%lld",ans2);break;
		case 3:printf("%lld\n%lld",ans1,ans2);break;
	}
	return 0;
}

rp++

posted @ 2019-09-05 17:20  HEOI-动动  阅读(282)  评论(0编辑  收藏  举报