luogu P4428 [BJOI2018]二进制
先考虑怎样的二进制串才会被3整除.可以发现如果二进制位第\(0,2,4...2n\)位如果为\(1\),那么在模3意义下为1,如果二进制位第\(1,3,5...2n+1\)位如果为\(1\),那么在模3意义下为-1.所以也就是位置上是1的奇二进制位个数减位置上是1的偶二进制位个数要被3整除
在这种条件下,如果区间内1的个数为偶数显然可以从最低位开始依次放使得被3整除,如果为奇数,那么先把除了最后三个1以外的1按照偶数的情况处理,然后这三个1中间各插入一个0,也就是\(...0101011...1\).那么,不合法的情况就只剩下有区间内奇数个1同时0的个数\(<2\),或者是区间内只有一个1
合法区间比较麻烦,改为求总区间个数-不合法区间个数.为了不算重,把不合法条件改为只剩下有区间内奇数个1同时0的个数\(<2\),或者是区间内只有一个1同时\(\ge 2\).我们用线段树维护这些区间个数,对每个节点记一个\(ls_{i,j}\)表示左端点为这个线段树节点对应区间左端点的区间中,1的个数奇偶性为\(0/1\),0的个数为\(0/1\)的区间个数,\(rs_{i,j}\)表示的是右端点为线段树节点右端点的相应的区间个数;\(lz_{i,j}\)表示左端点为线段树节点左端点的区间中,1的个数为\(0/1\),0的个数为\(0/1/\ge 2\)的区间个数,\(rz_{i,j}\)表示的是右端点为线段树节点右端点的相应的区间个数.以及分别记录区间\(0/1\)个数和不合法区间个数,每次合并两个节点,就计算跨越这两个节点的区间信息,可能需要一点点讨论,这里不再赘述
#include<bits/stdc++.h>
#define LL long long
#define uLL unsigned long long
#define db double
using namespace std;
const int N=1e5+10;
int rd()
{
int x=0,w=1;char ch=0;
while(ch<'0'||ch>'9'){if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
struct node
{
LL c0,c1,s;
LL ls[2][2],rs[2][2];
LL lz[2][3],rz[2][3];
void clr(){memset(ls,0,sizeof(ls)),memset(rs,0,sizeof(rs)),memset(lz,0,sizeof(lz)),memset(rz,0,sizeof(rz)),c0=c1=s=0;}
node(){}
node(int x)
{
memset(ls,0,sizeof(ls)),memset(rs,0,sizeof(rs)),memset(lz,0,sizeof(lz)),memset(rz,0,sizeof(rz)),c0=c1=s=0;
if(!x)
{
c0=1;
ls[0][1]=rs[0][1]=lz[0][1]=rz[0][1]=1;
}
else
{
s=c1=1;
ls[1][0]=rs[1][0]=lz[1][0]=rz[1][0]=1;
}
}
}s[N<<2],an;
node merg(node aa,node bb)
{
an.clr();
an.c0=aa.c0+bb.c0;
an.c1=aa.c1+bb.c1;
an.s=aa.s+bb.s;
for(int i=0;i<=1;++i)
for(int j=0;j<=1;++j)
{
an.ls[i][j]+=aa.ls[i][j];
an.rs[i][j]+=bb.rs[i][j];
if(aa.c0+j<=1) an.ls[(aa.c1&1)^i][aa.c0+j]+=bb.ls[i][j];
if(bb.c0+j<=1) an.rs[(bb.c1&1)^i][bb.c0+j]+=aa.rs[i][j];
}
for(int i=0;i<=1;++i)
for(int j=0;j<=1;++j)
for(int k=0;k<=1;++k)
for(int l=0;l<=1;++l)
if((i^k)==1&&j+l<=1) an.s+=aa.rs[i][j]*bb.ls[k][l];
for(int i=0;i<=1;++i)
for(int j=0;j<=2;++j)
{
an.lz[i][j]+=aa.lz[i][j];
an.rz[i][j]+=bb.rz[i][j];
if(aa.c1+i<=1) an.lz[aa.c1+i][min(aa.c0+j,2ll)]+=bb.lz[i][j];
if(bb.c1+i<=1) an.rz[bb.c1+i][min(bb.c0+j,2ll)]+=aa.rz[i][j];
}
for(int i=0;i<=1;++i)
for(int j=0;j<=2;++j)
for(int k=0;k<=1;++k)
for(int l=0;l<=2;++l)
if(i+k==1&&j+l>=2) an.s+=aa.rz[i][j]*bb.lz[k][l];
return an;
}
int n,a[N];
void psup(int o){s[o]=merg(s[o<<1],s[o<<1|1]);}
void modif(int o,int l,int r,int lx)
{
if(l==r){a[l]^=1;s[o]=node(a[l]);return;}
int mid=(l+r)>>1;
if(lx<=mid) modif(o<<1,l,mid,lx);
else modif(o<<1|1,mid+1,r,lx);
psup(o);
}
node quer(int o,int l,int r,int ll,int rr)
{
if(ll<=l&&r<=rr) return s[o];
int mid=(l+r)>>1;
if(rr<=mid) return quer(o<<1,l,mid,ll,rr);
if(ll>mid) return quer(o<<1|1,mid+1,r,ll,rr);
return merg(quer(o<<1,l,mid,ll,mid),quer(o<<1|1,mid+1,r,mid+1,rr));
}
void bui(int o,int l,int r)
{
if(l==r){s[o]=node(a[l]);return;}
int mid=(l+r)>>1;
bui(o<<1,l,mid),bui(o<<1|1,mid+1,r);
psup(o);
}
int main()
{
n=rd();
for(int i=1;i<=n;++i) a[i]=rd();
bui(1,1,n);
int q=rd();
while(q--)
{
int op=rd();
if(op==1) modif(1,1,n,rd());
else
{
int l=rd(),r=rd();
printf("%lld\n",1ll*(r-l+1)*(r-l+2)/2-quer(1,1,n,l,r).s);
}
}
return 0;
}