hdu 3397 Sequence operation(线段树的延迟标记)

做这道题之前建议先做:hdu 3911是这道题的一部分,这是我的博客链接:http://www.cnblogs.com/jiangjing/archive/2013/01/16/2863266.html

题意:首先给出一组数据:由0和1组成,然后有5种操作,0 a b表示把[a,b]区间的数全部变成0;1 a b 表示把[a,b]区间的数全部变成1;2 a b表示把[a,b]区间的0变成1、1变成0,也就是进行异或操作;3 a b就是问你[a,b]区间总共有多少个1,;4 a b就是问你[a,b]区间最长的连续的1的个数。

代码实现:

#include<iostream>
using namespace std;
struct node{
    int l,r;
    int lone,lzero;//分别是从左数1的个数和0的个数
    int rone,rzero;//分别是从右数1的个数和0的个数
    int max1,max0;//分别是连续的最多的1的个数和连续的最长0的个数
    int total1,flag0,flag1,flag2,mlen;//分别是总共1的个数、是否进行了0操作,1操作,2操作的标记
}p[1000001];//这里有点坑开了100001*4的大小一直是Runtime erroy
int a[1000001];
int max(int x,int y)
{
    return x>y?x:y;
}
int min(int x,int y)
{
    return x<y?x:y;
}
void pushup(int n)//往上更新
{
    p[n].lone=p[n*2].lone;
    if(p[n*2].lone==p[n*2].mlen)
        p[n].lone+=p[n*2+1].lone;
    p[n].lzero=p[n*2].lzero;
    if(p[n*2].lzero==p[n*2].mlen)//左右子树可以合并
        p[n].lzero+=p[n*2+1].lzero;
    p[n].rone=p[n*2+1].rone;
    if(p[n*2+1].rone==p[n*2+1].mlen)//左右子树可以合并
        p[n].rone+=p[n*2].rone;
    p[n].rzero=p[n*2+1].rzero;
    if(p[n*2+1].rzero==p[n*2+1].mlen)//左右子树可以合并
        p[n].rzero+=p[n*2].rzero;
    p[n].max1=max(p[n*2].max1,p[n*2+1].max1);
    p[n].max1=max(p[n].max1,p[n*2].rone+p[n*2+1].lone);
    p[n].max0=max(p[n*2].max0,p[n*2+1].max0);
    p[n].max0=max(p[n].max0,p[n*2].rzero+p[n*2+1].lzero);
    p[n].total1=p[n*2].total1+p[n*2+1].total1;
}
void build(int l,int r,int n)//建立线段树
{
    p[n].l=l;
    p[n].r=r;
    p[n].flag0=p[n].flag1=p[n].flag2=0;
    p[n].mlen=(r-l+1);
    if(l==r)
    {
        if(a[l]==1)
        {
            p[n].total1=1;
            p[n].lone=1;
            p[n].lzero=0;
            p[n].rone=1;
            p[n].rzero=0;
            p[n].max1=1;
            p[n].max0=0;
        }
        else
        {
            p[n].total1=0;
            p[n].lone=0;
            p[n].lzero=1;
            p[n].rone=0;
            p[n].rzero=1;
            p[n].max1=0;
            p[n].max0=1;
        }
        return ;
    }
    int mid=(l+r)/2;
    build(l,mid,n*2);
    build(mid+1,r,n*2+1);
    pushup(n);//往上更新
}
void pushdown(int n,int flag)//往下更新
{
    if(flag==0)
    {
        p[n*2].flag0=p[n*2+1].flag0=1;
        p[n*2].flag1=p[n*2].flag2=0;//flag1和flag2不管以前是0还是1全部变成0
        p[n*2+1].flag1=p[n*2+1].flag2=0;
        p[n].flag0=0;
        p[n*2].total1=0;
        p[n*2].lone=0;
        p[n*2].lzero=p[n*2].mlen;
        p[n*2].rone=0;
        p[n*2].rzero=p[n*2].mlen;
        p[n*2].max1=0;
        p[n*2].max0=p[n*2].mlen;

        p[n*2+1].total1=0;
        p[n*2+1].lone=0;
        p[n*2+1].lzero=p[n*2+1].mlen;
        p[n*2+1].rone=0;
        p[n*2+1].rzero=p[n*2+1].mlen;
        p[n*2+1].max1=0;
        p[n*2+1].max0=p[n*2+1].mlen;
    }
    else if(flag==1)
    {
        p[n*2].flag1=p[n*2+1].flag1=1;
        p[n*2].flag0=p[n*2].flag2=0;//flag0和flag2不管以前是0还是1全部变成0
        p[n*2+1].flag0=p[n*2+1].flag2=0;
        p[n].flag1=0;
        p[n*2].total1=p[n*2].mlen;
        p[n*2].lone=p[n*2].mlen;
        p[n*2].lzero=0;
        p[n*2].rone=p[n*2].mlen;
        p[n*2].rzero=0;
        p[n*2].max1=p[n*2].mlen;
        p[n*2].max0=0;

        p[n*2+1].total1=p[n*2+1].mlen;
        p[n*2+1].lone=p[n*2+1].mlen;
        p[n*2+1].lzero=0;
        p[n*2+1].rone=p[n*2+1].mlen;
        p[n*2+1].rzero=0;
        p[n*2+1].max1=p[n*2+1].mlen;
        p[n*2+1].max0=0;
    }
    else
    {
        p[n*2].flag2^=1;
        p[n*2+1].flag2^=1;
        p[n].flag2=0;
        swap(p[n*2].lone,p[n*2].lzero);
        swap(p[n*2].rone,p[n*2].rzero);
        swap(p[n*2].max1,p[n*2].max0);
        p[n*2].total1=p[n*2].mlen-p[n*2].total1;

        swap(p[n*2+1].lone,p[n*2+1].lzero);
        swap(p[n*2+1].rone,p[n*2+1].rzero);
        swap(p[n*2+1].max1,p[n*2+1].max0);
        p[n*2+1].total1=p[n*2+1].mlen-p[n*2+1].total1;
    }
}
void insert(int x,int y,int n,int nima)//更新
{
    if(x==p[n].l&&p[n].r==y)
    {
        if(nima==0)
        {
                p[n].flag0=1;
                p[n].flag1=0;
                p[n].flag2=0;
                p[n].lone=0;
                p[n].lzero=p[n].mlen;
                p[n].rone=0;
                p[n].rzero=p[n].mlen;
                p[n].max1=0;
                p[n].max0=p[n].mlen;
                p[n].total1=0; 
                return ;
        }
        else if(nima==1)
        {
                p[n].flag1=1;
                p[n].flag0=0;
                p[n].flag2=0;
                p[n].lone=p[n].mlen;
                p[n].lzero=0;
                p[n].rone=p[n].mlen;
                p[n].rzero=0;
                p[n].max1=p[n].mlen;
                p[n].max0=0;
                p[n].total1=p[n].mlen; 
                return ;
        }
        else//如果是异或的话,先判断下flag0和flag1是0还是1
        {
            p[n].flag2=p[n].flag2^1;
            if(p[n].flag0==1)//是1的话先往下更新
                pushdown(n,0);
            if(p[n].flag1==1)//是1的话先往下更新
                pushdown(n,1);
            swap(p[n].lone,p[n].lzero);
            swap(p[n].rone,p[n].rzero);
            swap(p[n].max1,p[n].max0);
            p[n].total1=p[n].mlen-p[n].total1;
            return ;
        }
    }
    if(p[n].l!=p[n].r)
    {
        if(p[n].flag0==1)
          pushdown(n,0);
        if(p[n].flag1==1)
          pushdown(n,1);
        if(p[n].flag2==1)
          pushdown(n,2);
    }
    int mid=(p[n].l+p[n].r)/2;
    if(y<=mid)
        insert(x,y,n*2,nima);
    else if(x>mid)
        insert(x,y,n*2+1,nima);
    else
    {
        insert(x,mid,n*2,nima);
        insert(mid+1,y,n*2+1,nima);
    }
    pushup(n);
}
int sum1(int x,int y,int n)//求区间总共有多少个1
{
    if(p[n].mlen==p[n].max1)
        return y-x+1;
    if(p[n].mlen==p[n].max0)
        return 0;
    if(x==p[n].l&&y==p[n].r)
        return p[n].total1;
    if(p[n].l!=p[n].r)
    {
        if(p[n].flag0==1)
          pushdown(n,0);
        if(p[n].flag1==1)
          pushdown(n,1);
        if(p[n].flag2==1)
          pushdown(n,2);
    }
    int mid=(p[n].l+p[n].r)/2;
    if(y<=mid)
        sum1(x,y,n*2);
    else if(x>mid)
        sum1(x,y,n*2+1);
    else 
        return sum1(x,mid,n*2)+sum1(mid+1,y,n*2+1);
}
int sum2(int x,int y,int n)//求区间连续1的最长的个数
{
    if(p[n].mlen==p[n].max1)
        return y-x+1;
    if(p[n].mlen==p[n].max0)
        return 0;
    if(x==p[n].l&&y==p[n].r)
       return p[n].max1;
    if(p[n].l!=p[n].r)
    {
        if(p[n].flag0==1)
          pushdown(n,0);
        if(p[n].flag1==1)
          pushdown(n,1);
        if(p[n].flag2==1)
          pushdown(n,2);
    }
    int mid=(p[n].l+p[n].r)/2;
    if(y<=mid)
        return sum2(x,y,n*2);
    else if(x>mid)
        return sum2(x,y,n*2+1);
    else
    {
        int left=0,right=0,midden=0;
        midden=midden+min(mid-x+1,p[n*2].rone)+min(y-mid,p[n*2+1].lone);
        left=sum2(x,mid,n*2);
        right=sum2(mid+1,y,n*2+1);
        return max(midden,max(left,right));
    }
}
int main()
{
    int T,n,m,i,nima,x,y;
    while(scanf("%d",&T)!=EOF)
    {
        while(T--)
        {
            scanf("%d%d",&n,&m);
            for(i=1;i<=n;i++)
                scanf("%d",&a[i]);
            build(1,n,1);
            while(m--)
            {
               scanf("%d%d%d",&nima,&x,&y);
               switch(nima)
               {
                 case 0:
                     insert(x+1,y+1,1,nima);break;
                 case 1:
                     insert(x+1,y+1,1,nima);break;
                 case 2:
                     insert(x+1,y+1,1,nima);break;
                 case 3:
                     printf("%d\n",sum1(x+1,y+1,1));break;
                 case 4:
                     printf("%d\n",sum2(x+1,y+1,1));break;     
               }
            }
        }
    }
    return 0;
}

 

 

posted on 2013-01-18 13:01  后端bug开发工程师  阅读(975)  评论(0编辑  收藏  举报

导航