[CSP-S模拟测试]:english(可持久化Trie+启发式合并)
题目传送门(内部题24)
输入格式
第一行有$3$个整数$n,opt$,$opt$的意义将在输出格式中提到。
第二行有$n$个整数,第$i$个整数表示$a_i$。
输出格式
若$opt=1$,输出一行一个整数表示${ans}_1$。
若$opt=2$,输出一行一个整数表示${ans}_2$。
若$opt=3$,输出两行,第一行一个整数${ans}_1$,第二行一个整数${ans}_2$。
样例
样例输入:
3 3
6 1 3
样例输出:
78
6
数据范围与提示
对于所有数据,$1\leqslant n\leqslant {10}^5,1\leqslant opt\leqslant 3,0\leqslant a_i\leqslant {10}^6$。
题解
对于每个数$a_x$,用单调栈找出它作为最大值的区间$[l_x,r_x]$,所有区间只有包含和不相交关系,没有相交关系,而且所以区间构成了一棵二叉树。
对每个区间$[l_x,r_x]$维护一棵$01trie$树$T_x$。
对每个区间$[l_x,r_x]$维护一个数组$f_x$,其中$f_{x,j}$表示该区间中第$j$位为$1$的数有多少个。
所以区间构成了一棵二叉树,可以对区间进行启发式合并,对于$a_x$控制的区间$[l_x,r_x]$,找到它的左右儿子$lch:[l_x,x−1]$和$rch:[x+1,r_x]$,我们只需要考虑所有包含$x$的区间的答案,而且这些区间的最大值都是$a_x$。
若左区间的长度$<$右区间的长度,我们可以枚举左区间中的每个数$a_i$。
对于${ans}_1$,我们可以分别统计每一个二进制位的答案,若$a_i$的第$j$位是$0$,那么第$j$位的贡献就是$2^jf{rch,j}$,若$a_i$的第$j$位是$1$,情况类似。同时,将 $a_i$更新到$f_x$中。
对于${ans}_2$,问题就转化成右区间中有多少个数$v$满足$v\ xor\ a_i>a_x$,可以在$T_{rch}$中查询。同时,将$a_i$插入到$trie$树$T_x$中。
时间复杂度:$\Theta(n\log n\log v)$。
期望得分:$100$分。
实际得分:$100$分。
代码时刻
#include<bits/stdc++.h> using namespace std; int n,opt; int a[100001],c[30],s[100001][30],sta[100001],sum[100001],l[100001],r[100001]; long long flag[30],d[30]; int rt[100001]; int trie[50000000][2],w[50000000],cnt; long long ans1,ans2; void add(int x,int y) { sum[y]++; for(int i=0;i<=21;i++) { s[y][i]+=x&1; x>>=1; } } void insert(int x,int l,int r) { for(int i=21;i>=0;i--) { int p=(x>>i)&1; w[l]=w[r]+1; trie[l][p^1]=trie[r][p^1]; trie[l][p]=++cnt; l=trie[l][p]; r=trie[r][p]; } w[l]=w[r]+1; } int ask(int x,int y,int l,int r) { int res=0,ans=0; for(int i=21;i>=0;i--) if((y>>i)&1) if(res+flag[i]>x) { ans+=w[trie[r][0]]-w[trie[l][0]]; l=trie[l][1]; r=trie[r][1]; } else { l=trie[l][0]; r=trie[r][0]; res+=flag[i]; } else if(res+flag[i]>x) { ans+=w[trie[r][1]]-w[trie[l][1]]; l=trie[l][0]; r=trie[r][0]; } else { l=trie[l][1]; r=trie[r][1]; res+=flag[i]; } return ans; } int main() { scanf("%d%d",&n,&opt); for(int i=1;i<=n;i++) scanf("%d",&a[i]); flag[0]=1;for(int i=1;i<=21;i++)flag[i]=flag[i-1]<<1; cnt=n; for(int i=1;i<=n;i++) { rt[i]=i; add(a[i],i); insert(a[i],rt[i],rt[i-1]); while(sta[0]&&a[sta[sta[0]]]<=a[i]) r[sta[sta[0]--]]=i-1; l[i]=sta[sta[0]]+1; sta[++sta[0]]=i; for(int j=0;j<=21;j++) s[i][j]+=s[i-1][j]; sum[i]+=sum[i-1]; } while(sta[0])r[sta[sta[0]--]]=n; for(int i=1;i<=n;i++) { long long res1=0,res2=0; if(i-l[i]<=r[i]-i) { for(int j=0;j<=21;j++) c[j]=s[r[i]][j]-s[i-1][j]; for(int j=l[i];j<=i;j++) { for(int k=0;k<=21;k++) { if((a[j]>>k)&1)d[k]=sum[r[i]]-sum[i-1]-c[k]; else d[k]=c[k]; res1=(res1+d[k]*flag[k])%1000000007; } res2=(res2+ask(a[i],a[j],rt[i-1],rt[r[i]]))%1000000007; } } else { for(int j=0;j<=21;j++) c[j]=s[i][j]-s[l[i]-1][j]; for(int j=i;j<=r[i];j++) { for(int k=0;k<=21;k++) { if((a[j]>>k)&1)d[k]=sum[i]-sum[l[i]-1]-c[k]; else d[k]=c[k]; res1=(res1+d[k]*flag[k])%1000000007; } res2=(res2+ask(a[i],a[j],rt[l[i]-1],rt[i]))%1000000007; } } ans1=(ans1+res1*a[i])%1000000007; ans2=(ans2+res2*a[i])%1000000007; } switch(opt) { case 1:printf("%lld",ans1);break; case 2:printf("%lld",ans2);break; case 3:printf("%lld\n%lld",ans1,ans2);break; } return 0; }
rp++