[codechef]SnackDown 2017 Online Elimination Round Prefix XOR

预处理后主席树维护

首先得出最后的答案为 ∑ri=lmin(right[i],r)−i+1∑i=lrmin(right[i],r)−i+1

如果不会主席树可以先去这里学

然后洛谷里有主席树模板可以练手

洛谷里面的题解都将的很好,不会打可以去看看。
step1

那么首要问题就是如何求出right[i]right[i]

考虑当i--j-1是上升时使区间i--j是上升的

即sum[i-1]sum[j-1]<=sum[i-1]sum[j]

观察到两边有差异的是sum[j-1]和sum[j] 也就意味着sum[j-1]和sum[j]的不同会对i的取值有限制

假设k为二进制下sum[j-1]与sum[j]最高的不同位
如果sum[j]此位为1对i的限制是sum[i-1]的此位不能为1
如果sum[j]此位为0对i的限制是sum[i-1]的此位不能为0

通过枚举每一位的限制即可得ri[i]ri[i]的最大合理值
step2

接下来就是利用主席数维护答案了

∑ri=lmin(right[i],r)−i+1∑i=lrmin(right[i],r)−i+1

我们可以对于所有的ri[i]ri[i]建设主席数 维护两个值

1.所有ri[i]ri[i]在i--j的范围内总和sum

2.所有ri[i]ri[i]在i--j的范围内有几个cnt

最后的答案及为l--r内ri[i]ri[i]的值在l--r内的sum+l--r内ri[i]ri[i]的值大于r的cnt××r-l--r所有数字和+(r-l+1)

以上是JHJ大佬的题解。题解传送门

但是大佬还有一些地方没标注,比如p数组是干啥的?

再请教JHJ大佬之后,在代码中加入了自己的理解。

然后。。。疯狂改bug,就在我想放弃,把代码交到网站上当作保存时奇迹出现了。

我TM居然过了。

下面是代码,有注释。

#include<bits/stdc++.h>
using namespace std;
const long long N=10000005;
long long p[35][2];
long long n,m,t,tot,lf[N],ri[N],sum[N],cnt[N],a[N];
long long f[N],root[N],x,y,l,r,anss;
long long read()
{
	long long x=0;char c='a';bool f=false;
	while (c<'0'||c>'9') {c=getchar();if (c=='-') f=1;}
	while (c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
	if (f)x=x*-1; 
	return x;
}

long long build(long long l,long long r)
{
	long long rt=++tot;
	if (l<r)
	{
		long long mid=(l+r)>>1;
		lf[rt]=build(l,mid);
		ri[rt]=build(mid+1,r);
	}
	return rt;
}

long long updata(long long pre,long long l,long long r,long long k)
{
	long long rt=++tot;
	lf[rt]=lf[pre];ri[rt]=ri[pre];
	sum[rt]=sum[pre]+k;cnt[rt]=cnt[pre]+1;
	if (l<r)
	{
		long long mid=(l+r)>>1;
		if (mid>=k) lf[rt]=updata(lf[pre],l,mid,k);
		else ri[rt]=updata(ri[pre],mid+1,r,k);
	}
	return rt;
}

long long getsum(long long x,long long y,long long L,long long R,long long l,long long r)//这个区间的ri的和 
{
	long long ans=0,mid=(l+r)>>1;
	if (L<=l&&r<=R) return sum[y]-sum[x];
	if (L<=mid) ans+=getsum(lf[x],lf[y],L,R,l,mid);
	if (R>mid) ans+=getsum(ri[x],ri[y],L,R,mid+1,r);
	return ans;
}

long long getcnt(long long x,long long y,long long L,long long R,long long l,long long r)//这个区间的数的个数 
{
	long long ans=0,mid=(l+r)>>1;
	if (L<=l&&r<=R) return cnt[y]-cnt[x];
	if (L<=mid) ans+=getcnt(lf[x],lf[y],L,R,l,mid);
	if (R>mid) ans+=getcnt(ri[x],ri[y],L,R,mid+1,r);
	return ans;
}

long long get(long long x,long long y)
{
	return x*(x-1)/2-y*(y-1)/2;
}

int main()
{
	//freopen("1.in","r",stdin);
	//freopen("1.out","w",stdout);
	n=read();t=read();//n个数,后面要乘t对n取模 
	for (long long i=1;i<=n;i++)
	a[i]=read(),a[i]^=a[i-1];//前缀异或和 
	m=read();
	memset(p,63,sizeof(p));
	for(long long i=n;i>=1;i--)
    {
        f[i]=n;
        for(long long j=30;j>=0;j--)
		f[i]=min(f[i],p[j][(a[i-1]>>j)&1]-1);
        for(long long j=30;j>=0;j--)
		if(((a[i]>>j)&1)^((a[i-1]>>j)&1))
		{
            p[j][(a[i]>>j)&1]=min(p[j][(a[i]>>j)&1],i);//f[x][y]表示二进制下第x位为y的情况,最远能取到哪里 
			break;//其中y为0或1,最远取值的意思是最远能取到哪个a[i] 
        }
    }
	root[0]=build(1,n);
	for (long long i=1;i<=n;i++)
	root[i]=updata(root[i-1],1,n,f[i]);
	for (long long i=1;i<=m;i++)
	{
		x=read();y=read();
		x=(x+anss*t)%n+1;y=(y+anss*t)%n+1;
		l=min(x,y);r=max(x,y);
		anss=getsum(root[l-1],root[r],l,r,1,n);
		anss+=r*getcnt(root[l-1],root[r],r+1,n,1,n);
		anss-=get(r,l-1);
		printf("%lld\n",anss);
	}
	return 0;
}
posted @ 2019-02-24 16:29  xzjds  阅读(153)  评论(0编辑  收藏  举报