luogu P4428 [BJOI2018]二进制

luogu

先考虑怎样的二进制串才会被3整除.可以发现如果二进制位第\(0,2,4...2n\)位如果为\(1\),那么在模3意义下为1,如果二进制位第\(1,3,5...2n+1\)位如果为\(1\),那么在模3意义下为-1.所以也就是位置上是1的奇二进制位个数减位置上是1的偶二进制位个数要被3整除

在这种条件下,如果区间内1的个数为偶数显然可以从最低位开始依次放使得被3整除,如果为奇数,那么先把除了最后三个1以外的1按照偶数的情况处理,然后这三个1中间各插入一个0,也就是\(...0101011...1\).那么,不合法的情况就只剩下有区间内奇数个1同时0的个数\(<2\),或者是区间内只有一个1

合法区间比较麻烦,改为求总区间个数-不合法区间个数.为了不算重,把不合法条件改为只剩下有区间内奇数个1同时0的个数\(<2\),或者是区间内只有一个1同时\(\ge 2\).我们用线段树维护这些区间个数,对每个节点记一个\(ls_{i,j}\)表示左端点为这个线段树节点对应区间左端点的区间中,1的个数奇偶性\(0/1\),0的个数为\(0/1\)的区间个数,\(rs_{i,j}\)表示的是右端点为线段树节点右端点的相应的区间个数;\(lz_{i,j}\)表示左端点为线段树节点左端点的区间中,1的个数为\(0/1\),0的个数为\(0/1/\ge 2\)的区间个数,\(rz_{i,j}\)表示的是右端点为线段树节点右端点的相应的区间个数.以及分别记录区间\(0/1\)个数和不合法区间个数,每次合并两个节点,就计算跨越这两个节点的区间信息,可能需要一点点讨论,这里不再赘述

#include<bits/stdc++.h>
#define LL long long
#define uLL unsigned long long
#define db double

using namespace std;
const int N=1e5+10;
int rd()
{
	int x=0,w=1;char ch=0;
	while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return x*w;
}
struct node
{
	LL c0,c1,s;
	LL ls[2][2],rs[2][2];
	LL lz[2][3],rz[2][3];
	void clr(){memset(ls,0,sizeof(ls)),memset(rs,0,sizeof(rs)),memset(lz,0,sizeof(lz)),memset(rz,0,sizeof(rz)),c0=c1=s=0;}
	node(){}
	node(int x)
	{
		memset(ls,0,sizeof(ls)),memset(rs,0,sizeof(rs)),memset(lz,0,sizeof(lz)),memset(rz,0,sizeof(rz)),c0=c1=s=0;
		if(!x)
		{
			c0=1;
			ls[0][1]=rs[0][1]=lz[0][1]=rz[0][1]=1;
		}
		else
		{
			s=c1=1;
			ls[1][0]=rs[1][0]=lz[1][0]=rz[1][0]=1;
		}
	}
}s[N<<2],an;
node merg(node aa,node bb)
{
	an.clr();
	an.c0=aa.c0+bb.c0;
	an.c1=aa.c1+bb.c1;
	an.s=aa.s+bb.s;
	for(int i=0;i<=1;++i)
		for(int j=0;j<=1;++j)
		{
			an.ls[i][j]+=aa.ls[i][j];
			an.rs[i][j]+=bb.rs[i][j];
			if(aa.c0+j<=1) an.ls[(aa.c1&1)^i][aa.c0+j]+=bb.ls[i][j];
			if(bb.c0+j<=1) an.rs[(bb.c1&1)^i][bb.c0+j]+=aa.rs[i][j];
		}
	for(int i=0;i<=1;++i)
		for(int j=0;j<=1;++j)
			for(int k=0;k<=1;++k)
				for(int l=0;l<=1;++l)
					if((i^k)==1&&j+l<=1) an.s+=aa.rs[i][j]*bb.ls[k][l];
	for(int i=0;i<=1;++i)
		for(int j=0;j<=2;++j)
		{
			an.lz[i][j]+=aa.lz[i][j];
			an.rz[i][j]+=bb.rz[i][j];
			if(aa.c1+i<=1) an.lz[aa.c1+i][min(aa.c0+j,2ll)]+=bb.lz[i][j];
			if(bb.c1+i<=1) an.rz[bb.c1+i][min(bb.c0+j,2ll)]+=aa.rz[i][j];
		}
	for(int i=0;i<=1;++i)
		for(int j=0;j<=2;++j)
			for(int k=0;k<=1;++k)
				for(int l=0;l<=2;++l)
					if(i+k==1&&j+l>=2) an.s+=aa.rz[i][j]*bb.lz[k][l];
	return an;
}
int n,a[N];
void psup(int o){s[o]=merg(s[o<<1],s[o<<1|1]);}
void modif(int o,int l,int r,int lx)
{
	if(l==r){a[l]^=1;s[o]=node(a[l]);return;}
	int mid=(l+r)>>1;
	if(lx<=mid) modif(o<<1,l,mid,lx);
	else modif(o<<1|1,mid+1,r,lx);
	psup(o);
}
node quer(int o,int l,int r,int ll,int rr)
{
	if(ll<=l&&r<=rr) return s[o];
	int mid=(l+r)>>1;
	if(rr<=mid) return quer(o<<1,l,mid,ll,rr);
	if(ll>mid) return quer(o<<1|1,mid+1,r,ll,rr);
	return merg(quer(o<<1,l,mid,ll,mid),quer(o<<1|1,mid+1,r,mid+1,rr));
}
void bui(int o,int l,int r)
{
	if(l==r){s[o]=node(a[l]);return;}
	int mid=(l+r)>>1;
	bui(o<<1,l,mid),bui(o<<1|1,mid+1,r);
	psup(o);
}

int main()
{
	n=rd();
	for(int i=1;i<=n;++i) a[i]=rd();
	bui(1,1,n);
	int q=rd();
	while(q--)
	{
		int op=rd();
		if(op==1) modif(1,1,n,rd());
		else
		{
			int l=rd(),r=rd();
			printf("%lld\n",1ll*(r-l+1)*(r-l+2)/2-quer(1,1,n,l,r).s);
		}
	}
	return 0;
}
posted @ 2019-10-13 22:43  ✡smy✡  阅读(184)  评论(0编辑  收藏  举报