bzoj 5294: [Bjoi2018]二进制【动态dp+线段树】

不太清楚是不是动态dp……?
这个维护其实和最大连续子段差不多,维护l[x][y],r[x][y],m[x][y]分别表示包含左儿子的01个数为(x,y)的区间个数,包含右儿子的01个数为(x,y)的区间个数,和01个数为(x,y)的所有区间个数
x表示1的个数情况,0表示0个,1表示1个,2表示>=2的偶数个,3表示>=3的奇数个
y表示0的个数情况,0表示0个,1表示1个,2表示>=2个
转移的话合并ls.r,rs.l即可,注意是乘法,注意细节,转移很难写……

#include<iostream>
#include<cstdio>
using namespace std;
const int N=100005;
int n,m,a[N];
struct xds
{
	long long l[4][3],r[4][3],m[4][3],s[2];
}t[N<<2];
int read()
{
	int r=0,f=1;
	char p=getchar();
	while(p>'9'||p<'0')
	{
		if(p=='-')
			f=-1;
		p=getchar();
	}
	while(p>='0'&&p<='9')
	{
		r=r*10+p-48;
		p=getchar();
	}
	return r*f;
}
int wk1(int x) 
{
    return (x<=1)?x:(x%2+2);
}
 
int wk0(int x) 
{
    return (x<=1)?x:2;
}
xds operator + (const xds &a,const xds &b)
{
	xds c;
	c.s[0]=a.s[0]+b.s[0];
	c.s[1]=a.s[1]+b.s[1];
	for(int i=0;i<=3;i++)
		for(int j=0;j<=2;j++)
		{
			c.l[i][j]=a.l[i][j];
			c.r[i][j]=b.r[i][j];
			c.m[i][j]=a.m[i][j]+b.m[i][j];
		}
	for(int i=0;i<=3;i++)
		for(int j=0;j<=2;j++)
			if(a.r[i][j])
				for(int k=0;k<=3;k++)
					for(int l=0;l<=2;l++)
						if(b.l[k][l])
							c.m[wk1(i+k)][wk0(j+l)]+=a.r[i][j]*b.l[k][l];
	for(int i=0;i<=3;i++)
		for(int j=0;j<=2;j++)
		{
			c.l[wk1(a.s[1]+i)][wk0(a.s[0]+j)]+=b.l[i][j];
			c.r[wk1(b.s[1]+i)][wk0(b.s[0]+j)]+=a.r[i][j];
		}
	return c;
}
void build(int ro,int l,int r)
{
	if(l==r)
	{
		int x=(a[l]==1),y=(a[l]==0);
        t[ro].s[0]=y,t[ro].s[1]=x;
        t[ro].m[x][y]=t[ro].l[x][y]=t[ro].r[x][y]=1;
		return;
	}
	int mid=(l+r)>>1;
	build(ro<<1,l,mid);
	build(ro<<1|1,mid+1,r);
	t[ro]=t[ro<<1]+t[ro<<1|1];
}
void update(int ro,int l,int r,int p)
{
	if(l==r)
	{
		int x=t[ro].s[1],y=t[ro].s[0];
        t[ro].m[x][y]=t[ro].l[x][y]=t[ro].r[x][y]=0;
        swap(t[ro].s[0],t[ro].s[1]);
        t[ro].m[y][x]=t[ro].l[y][x]=t[ro].r[y][x]=1;
		return;
	}
	int mid=(l+r)>>1;
	if(p<=mid)
		update(ro<<1,l,mid,p);
	else
		update(ro<<1|1,mid+1,r,p);
	t[ro]=t[ro<<1]+t[ro<<1|1];
}
xds ques(int ro,int l,int r,int x,int y)
{
	if(l==x&&r==y)
		return t[ro];
	int mid=(l+r)>>1;
	if(y<=mid)
		return ques(ro<<1,l,mid,x,y);
	else if(x>mid)
		return ques(ro<<1|1,mid+1,r,x,y);
	else
		return ques(ro<<1,l,mid,x,mid)+ques(ro<<1|1,mid+1,r,mid+1,y);
}
int main()
{
	n=read();
	for(int i=1;i<=n;i++)
		a[i]=read();
	build(1,1,n);
	m=read();
	while(m--)
	{
		int o=read();
		if(o==1)
		{
			int x=read();
			update(1,1,n,x);
		}
		else
		{
			int l=read(),r=read();
			xds x=ques(1,1,n,l,r);
			printf("%lld\n",x.m[0][0]+x.m[0][1]+x.m[0][2]+x.m[2][0]+x.m[2][1]+x.m[2][2]+x.m[3][2]);
		}
	}
	return 0;
}
posted @ 2018-12-10 07:57  lokiii  阅读(336)  评论(1编辑  收藏  举报