[NOI2005]维护数列

一道splay综合大板子题。

题面:https://www.lydsy.com/JudgeOnline/problem.php?id=1500

下面是题解:

首先对每个点维护这些量:

1.两个儿子(ch[2])

2.父节点(fa)

3.当前点权值(vl)和子树权值(sum)

4.修改标记(xg),反转标记(fz)。

5.mx,mxl,mxr(不清楚的建议做小白逛公园)。

题中还有一个条件,即:

100%的数据中,任何时刻数列中最多含有 500 000 个数。

这就解决了空间问题。我们可以用一个队列记录有多少点可以回收,这样就可以节省大量的空间。

最后一点:本题中涉及到max值,区间反转后需要update。

代码:

#include<queue>
#include<cstdio>
#include<algorithm>
using namespace std;
#define N 1000050
int n,m,a[N],rt,tot,pos,cnt,id[N];
queue<int>que;
char ch[15];
struct Splay
{
    int ch[2];
    int xg,fz,fa;
    int vl,sum,siz,mx,mxl,mxr;
}tr[N];
void update(int u)
{
    int l= tr[u].ch[0] , r = tr[u].ch[1];
    tr[u].sum = tr[l].sum+tr[r].sum+tr[u].vl;
    tr[u].siz = tr[l].siz+tr[r].siz+1;
    tr[u].mx = max(max(tr[l].mx,tr[r].mx),tr[l].mxr+tr[r].mxl+tr[u].vl);
    tr[u].mxl = max(tr[l].mxl,tr[l].sum+tr[u].vl+tr[r].mxl);
    tr[u].mxr = max(tr[r].mxr,tr[r].sum+tr[u].vl+tr[l].mxr);
}
void pushdown(int u)
{
    int l = tr[u].ch[0],r = tr[u].ch[1];
    if(tr[u].xg)
    {
        tr[u].xg = tr[u].fz = 0;
        if(l)tr[l].xg = 1,tr[l].vl = tr[u].vl,tr[l].sum = tr[l].siz*tr[l].vl;
        if(r)tr[r].xg = 1,tr[r].vl = tr[u].vl,tr[r].sum = tr[r].siz*tr[r].vl;
        if(tr[u].vl>0)
        {
            if(l)tr[l].mx=tr[l].mxl=tr[l].mxr=tr[l].sum;
            if(r)tr[r].mx=tr[r].mxl=tr[r].mxr=tr[r].sum;
        }else
        {
            if(l)tr[l].mx=tr[l].vl,tr[l].mxl=tr[l].mxr=0;
            if(r)tr[r].mx=tr[r].vl,tr[r].mxl=tr[r].mxr=0;
        }
    }
    if(tr[u].fz)
    {
        tr[u].fz = 0;
        tr[l].fz^=1,tr[r].fz^=1;
        swap(tr[l].ch[0],tr[l].ch[1]);
        swap(tr[r].ch[0],tr[r].ch[1]);
        swap(tr[l].mxl,tr[l].mxr);
        swap(tr[r].mxl,tr[r].mxr);
    }
}
void rotate(int x)
{
    int y = tr[x].fa;
    int z = tr[y].fa;
    int k = (tr[y].ch[1]==x);
    tr[tr[x].ch[k^1]].fa = y,tr[y].ch[k] = tr[x].ch[k^1];
    tr[x].ch[k^1] = y,tr[y].fa = x;
    tr[x].fa = z,tr[z].ch[tr[z].ch[1]==y]=x;
    update(y);update(x);
}
void splay(int u,int goal)
{
    while(tr[u].fa!=goal)
    {
        int y = tr[u].fa;
        int z = tr[y].fa;
        if(z!=goal)
            ((tr[y].ch[1]==u)^(tr[z].ch[1]==y))?rotate(u):rotate(y);
        rotate(u);
    }
    if(!goal)rt=u;
}
void build(int l,int r,int f)
{
    if(l>r)return ;
    int mid = (l+r)>>1,u = id[mid],fa = id[f];
    if(l==r)
    {
        tr[u].siz=1;
        if(a[l]>0)tr[u].mx=tr[u].mxl=tr[u].mxr=a[l];
        else tr[u].mx=a[l],tr[u].mxl=tr[u].mxr=0;
    }else
    {
        build(l,mid-1,mid);
        build(mid+1,r,mid);
    }
    tr[u].vl = a[mid];
    tr[u].fa = fa;
    update(u);
    tr[fa].ch[mid>=f] = u;
}
int find(int x,int k)
{
    pushdown(x);
    int t = tr[tr[x].ch[0]].siz;
    if(k<=t)return find(tr[x].ch[0],k);
    else if(k==t+1)return x;
    else return find(tr[x].ch[1],k-1-t);
}
void insert(int k,int tt)
{
    for(int i=1;i<=tt;i++)
    {
        scanf("%d",&a[i]);
    }
    for(int i=1;i<=tt;i++)
    {
        if(!que.empty())
        {
            id[i]=que.front();
            que.pop();
        }else
        {
            id[i]=++cnt;
        }
    }
    build(1,tt,0);
    int rt0 = id[(1+tt)>>1];
    int l = find(rt,k+1);
    int r = find(rt,k+2);
    splay(l,0);
    splay(r,l);
    tr[r].ch[0]=rt0;
    tr[rt0].fa=r;
    update(r);
    update(rt);
}
void rip(int x)
{
    if(!x)return ;
    int l = tr[x].ch[0],r = tr[x].ch[1];
    rip(l);
    rip(r);
    que.push(x);
    tr[x].xg=tr[x].fz=tr[x].vl=tr[x].sum=0;
    tr[x].siz=tr[x].mx=tr[x].mxl=tr[x].mxr=tr[x].ch[0]=tr[x].ch[1]=0;
}
int deal(int l,int r)
{
    l = find(rt,l);
    r = find(rt,r);
    splay(l,0);
    splay(r,l);
    return tr[r].ch[0];
}
void erase(int l,int r)
{
    int x = deal(l,r),y=tr[x].fa;
    rip(x);
    tr[y].ch[0]=0;
    update(y);
    update(rt);
}
void make_same(int l,int r,int k)
{
    int x = deal(l,r),y=tr[x].fa;
    tr[x].vl=k,tr[x].xg=1;
    tr[x].sum = tr[x].siz*k;
    if(k>0)tr[x].mx=tr[x].mxl=tr[x].mxr=tr[x].sum;
    else tr[x].mx=k,tr[x].mxl=tr[x].mxr=0;
    update(y);
    update(rt);
}
void rever(int l,int r)
{
    int x = deal(l,r);
    if(!tr[x].xg)
    {
        tr[x].fz^=1;
        swap(tr[x].ch[0],tr[x].ch[1]);
        swap(tr[x].mxl,tr[x].mxr);
        update(tr[x].fa),update(rt);
    }
}
int get_sum(int l,int r)
{
    int x = deal(l,r);
    return tr[x].sum;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)scanf("%d",&a[i+1]);
    for(int i=1;i<=n+2;i++)id[i]=i;
    tr[0].mx=a[1]=a[n+2]=-0x3f3f3f3f;
    build(1,n+2,0);
    cnt = n+2;
    rt = (n+3)>>1;
    for(int k,i=1;i<=m;i++)
    {
        scanf("%s",ch);
        if(ch[0]=='I')
        {
            scanf("%d%d",&pos,&tot);
            insert(pos,tot);
        }else if(ch[0]=='D')
        {
            scanf("%d%d",&pos,&tot);
            erase(pos,pos+tot+1);
        }else if(ch[2]=='K')
        {
            scanf("%d%d%d",&pos,&tot,&k);
            make_same(pos,pos+tot+1,k);
        }else if(ch[0]=='R')
        {
            scanf("%d%d",&pos,&tot);
            rever(pos,pos+tot+1);
        }else if(ch[0]=='G')
        {
            scanf("%d%d",&pos,&tot);
            printf("%d\n",get_sum(pos,pos+tot+1));
        }else
        {
            printf("%d\n",tr[rt].mx);
        }
    }
    return 0;
}

 

posted @ 2018-09-08 01:47  LiGuanlin  阅读(330)  评论(1编辑  收藏  举报