[codechef]SnackDown 2017 Online Elimination Round Prefix XOR
预处理后主席树维护
首先得出最后的答案为 ∑ri=lmin(right[i],r)−i+1∑i=lrmin(right[i],r)−i+1
洛谷里面的题解都将的很好,不会打可以去看看。
step1
那么首要问题就是如何求出right[i]right[i]
考虑当i--j-1是上升时使区间i--j是上升的
即sum[i-1]sum[j-1]<=sum[i-1]sum[j]
观察到两边有差异的是sum[j-1]和sum[j] 也就意味着sum[j-1]和sum[j]的不同会对i的取值有限制
假设k为二进制下sum[j-1]与sum[j]最高的不同位
如果sum[j]此位为1对i的限制是sum[i-1]的此位不能为1
如果sum[j]此位为0对i的限制是sum[i-1]的此位不能为0
通过枚举每一位的限制即可得ri[i]ri[i]的最大合理值
step2
接下来就是利用主席数维护答案了
∑ri=lmin(right[i],r)−i+1∑i=lrmin(right[i],r)−i+1
我们可以对于所有的ri[i]ri[i]建设主席数 维护两个值
1.所有ri[i]ri[i]在i--j的范围内总和sum
2.所有ri[i]ri[i]在i--j的范围内有几个cnt
最后的答案及为l--r内ri[i]ri[i]的值在l--r内的sum+l--r内ri[i]ri[i]的值大于r的cnt××r-l--r所有数字和+(r-l+1)
以上是JHJ大佬的题解。题解传送门
但是大佬还有一些地方没标注,比如p数组是干啥的?
再请教JHJ大佬之后,在代码中加入了自己的理解。
然后。。。疯狂改bug,就在我想放弃,把代码交到网站上当作保存时奇迹出现了。
我TM居然过了。
下面是代码,有注释。
#include<bits/stdc++.h>
using namespace std;
const long long N=10000005;
long long p[35][2];
long long n,m,t,tot,lf[N],ri[N],sum[N],cnt[N],a[N];
long long f[N],root[N],x,y,l,r,anss;
long long read()
{
long long x=0;char c='a';bool f=false;
while (c<'0'||c>'9') {c=getchar();if (c=='-') f=1;}
while (c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
if (f)x=x*-1;
return x;
}
long long build(long long l,long long r)
{
long long rt=++tot;
if (l<r)
{
long long mid=(l+r)>>1;
lf[rt]=build(l,mid);
ri[rt]=build(mid+1,r);
}
return rt;
}
long long updata(long long pre,long long l,long long r,long long k)
{
long long rt=++tot;
lf[rt]=lf[pre];ri[rt]=ri[pre];
sum[rt]=sum[pre]+k;cnt[rt]=cnt[pre]+1;
if (l<r)
{
long long mid=(l+r)>>1;
if (mid>=k) lf[rt]=updata(lf[pre],l,mid,k);
else ri[rt]=updata(ri[pre],mid+1,r,k);
}
return rt;
}
long long getsum(long long x,long long y,long long L,long long R,long long l,long long r)//这个区间的ri的和
{
long long ans=0,mid=(l+r)>>1;
if (L<=l&&r<=R) return sum[y]-sum[x];
if (L<=mid) ans+=getsum(lf[x],lf[y],L,R,l,mid);
if (R>mid) ans+=getsum(ri[x],ri[y],L,R,mid+1,r);
return ans;
}
long long getcnt(long long x,long long y,long long L,long long R,long long l,long long r)//这个区间的数的个数
{
long long ans=0,mid=(l+r)>>1;
if (L<=l&&r<=R) return cnt[y]-cnt[x];
if (L<=mid) ans+=getcnt(lf[x],lf[y],L,R,l,mid);
if (R>mid) ans+=getcnt(ri[x],ri[y],L,R,mid+1,r);
return ans;
}
long long get(long long x,long long y)
{
return x*(x-1)/2-y*(y-1)/2;
}
int main()
{
//freopen("1.in","r",stdin);
//freopen("1.out","w",stdout);
n=read();t=read();//n个数,后面要乘t对n取模
for (long long i=1;i<=n;i++)
a[i]=read(),a[i]^=a[i-1];//前缀异或和
m=read();
memset(p,63,sizeof(p));
for(long long i=n;i>=1;i--)
{
f[i]=n;
for(long long j=30;j>=0;j--)
f[i]=min(f[i],p[j][(a[i-1]>>j)&1]-1);
for(long long j=30;j>=0;j--)
if(((a[i]>>j)&1)^((a[i-1]>>j)&1))
{
p[j][(a[i]>>j)&1]=min(p[j][(a[i]>>j)&1],i);//f[x][y]表示二进制下第x位为y的情况,最远能取到哪里
break;//其中y为0或1,最远取值的意思是最远能取到哪个a[i]
}
}
root[0]=build(1,n);
for (long long i=1;i<=n;i++)
root[i]=updata(root[i-1],1,n,f[i]);
for (long long i=1;i<=m;i++)
{
x=read();y=read();
x=(x+anss*t)%n+1;y=(y+anss*t)%n+1;
l=min(x,y);r=max(x,y);
anss=getsum(root[l-1],root[r],l,r,1,n);
anss+=r*getcnt(root[l-1],root[r],r+1,n,1,n);
anss-=get(r,l-1);
printf("%lld\n",anss);
}
return 0;
}