维修数列（Splay树）

From:http://www.lydsy.com/JudgeOnline/problem.php?id=1500

Solution:

非常变态（BT）的数据结构题。一个技巧：在序列首尾各插入一个哑节点，把所有区间操作都转换成“Root的右孩子的左子树”这一形式来维护。即对于区间[a,b]，把第a-1个元素旋转到Root，再把第b+1个元素旋转到Root的右孩子，那么区间[a,b]恰好就是Root右孩子的左子树。首尾的哑节点保证了a-1和b+1总是存在。

关于最大子段和的维护：类比线段树，每个节点同时维护子树的和、最大前缀和、最大后缀和以及最大子段和，合并时按“完全在左子树 / 完全在右子树 / 跨过当前节点”分情况讨论，即可凑出所有情况。

关于翻转操作：先翻转再旋转与先旋转再翻转是等价的（随便画个图就能验证），因此翻转可以作为懒标记保存在节点上，等到需要访问或旋转时再下传。

/**************************************************************
    Problem: 1500
    User: leezy
    Language: C++
    Result: Accepted
    Time:5252 ms
    Memory:26668 kb
****************************************************************/
 
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <cmath>
using namespace std;
// Size of the node pool t and of the buffers ar / ar2.
const int N = 500015;
// Not referenced anywhere in this listing; kept so nothing depending on it breaks.
const int M = 1000000;
// Large sentinel magnitude; -INF stands for "no non-empty segment".
const int INF = 1 << 30;

// Highest node index handed out so far.
int cnt;
// Index of the current splay-tree root (0 = none).
int root;
// Number of recycled node indices currently stored on the ar2 stack.
int cnt2;
// Scratch buffer of values that build() turns into a subtree.
int ar[N];
// Stack of recycled node indices (filled by remove, drained by new_node).
int ar2[N];

// One splay-tree node; index 0 serves as the shared null node.
struct SPLAY{
    int v;      // value stored at this node
    int l, r;   // left / right child index (0 = none)
    int p;      // parent index (0 = this node is the root)
    int s;      // number of nodes in this subtree
    int same;   // lazy flag: children must still be assigned v
    int rev;    // lazy flag: children's subtrees must still be reversed
    int sum;    // sum of all values in this subtree
    int Lmax;   // best sum of a non-empty prefix of the subtree
    int Rmax;   // best sum of a non-empty suffix of the subtree
    int Max;    // best sum of a non-empty contiguous segment of the subtree
} t[N];

// Right child of the root: once a range is framed by two splays, its left subtree is that range.
#define KT t[root].r
// Allocates a leaf node holding value v and returns its index.
// Indices recycled by remove (kept on the ar2 stack) are reused before fresh ones.
// The node starts detached (no parent/children), with size 1, no pending tags,
// and every sum/maximum equal to v.
int new_node(int v){
    int id = cnt2 ? ar2[cnt2--] : ++cnt;
    t[id].l = 0;
    t[id].r = 0;
    t[id].p = 0;
    t[id].s = 1;
    t[id].same = 0;
    t[id].rev = 0;
    t[id].v = t[id].sum = v;
    t[id].Lmax = t[id].Rmax = t[id].Max = v;
    return id;
}
// Assigns v to every element of x's subtree (no-op for the null node 0).
// x's own value and aggregates are set right away; the `same` flag makes
// upd_down pass v on to x's children later.
// For s copies of v, the best non-empty segment is the whole run (s*v) when
// v > 0, and a single element (v) otherwise.
void upd_same(int x, int v){
    if ( x == 0 ) return;
    t[x].same = 1;
    t[x].v = v;
    t[x].sum = v * t[x].s;
    int best = t[x].sum > v ? t[x].sum : v;
    t[x].Max = t[x].Rmax = t[x].Lmax = best;
}
// Reverses x's subtree lazily (no-op for the null node 0): x's children and its
// best prefix/suffix sums are swapped now, and the `rev` flag makes upd_down
// reverse the children's subtrees later.
void upd_rev(int x){
    if ( x == 0 ) return;
    // !! Must toggle, not set: reversing twice restores the original order.
    // Getting this wrong caused countless Wrong Answer verdicts.
    t[x].rev ^= 1;
    swap(t[x].Lmax, t[x].Rmax);
    swap(t[x].l, t[x].r);
}
// Pushes x's pending lazy tags down to both children: first the value
// assignment (`same`), then the reversal (`rev`). Each tag is cleared once pushed.
void upd_down(int x){
    // Pushing tags never changes x's own child links, so read them once.
    const int kid[2] = { t[x].l, t[x].r };
    if ( t[x].same ){
        for ( int i = 0; i < 2; ++i ) upd_same(kid[i], t[x].v);
        t[x].same = 0;
    }
    if ( t[x].rev ){
        for ( int i = 0; i < 2; ++i ) upd_rev(kid[i]);
        t[x].rev = 0;
    }
}
void upd_up(int x){
    if ( x ) {
        t[x].s = t[t[x].l].s + t[t[x].r].s + 1;
        t[x].sum = t[t[x].l].sum + t[t[x].r].sum + t[x].v;
        t[x].Lmax = max(t[t[x].l].Lmax, t[t[x].l].sum + t[x].v + max(0, t[t[x].r].Lmax));
        t[x].Rmax = max(t[t[x].r].Rmax, t[t[x].r].sum + t[x].v + max(0, t[t[x].l].Rmax));
        int Lmax = max(0, t[t[x].r].Lmax);
        int Rmax = max(0, t[t[x].l].Rmax);
        int Max1 = max(t[t[x].l].Max, t[t[x].r].Max);
        t[x].Max = max(t[x].v, Rmax + Lmax + t[x].v);
        t[x].Max = max(Max1, t[x].Max);
    }
}
// Builds a balanced subtree from ar[l..r] (inclusive) and returns its root
// index, or 0 for an empty range. Nodes are allocated in pre-order (middle
// element first). The returned root's parent link is left at 0 for the caller to set.
int build(int l, int r){
    if ( l > r ) return 0;
    int mid = (l + r) / 2;
    int node = new_node(ar[mid]);
    int lc = build(l, mid - 1);
    int rc = build(mid + 1, r);
    t[node].l = lc;
    t[node].r = rc;
    if ( lc ) t[lc].p = node;
    if ( rc ) t[rc].p = node;
    upd_up(node);
    return node;
}
// Left rotation at x: x's right child y takes x's place under x's parent, and
// x becomes y's left child. y's old left subtree moves to x's right side.
// Pending tags of y and x are pushed first. Aggregates are rebuilt bottom-up
// (x, then y). If x was the root, y becomes the new root.
void rot_l(int x){
    int y = t[x].r;
    upd_down(y);
    upd_down(x);
    int g = t[x].p;
    int mid = t[y].l;
    t[x].r = mid;
    t[mid].p = x;           // also touches the null node when mid == 0; harmless
    t[y].p = g;
    if ( g == 0 ) root = y;
    else if ( t[g].l == x ) t[g].l = y;
    else t[g].r = y;
    t[y].l = x;
    t[x].p = y;
    upd_up(x);
    upd_up(y);
}
// Right rotation at x: x's left child y takes x's place under x's parent, and
// x becomes y's right child. y's old right subtree moves to x's left side.
// Pending tags of y and x are pushed first. Aggregates are rebuilt bottom-up
// (x, then y). If x was the root, y becomes the new root.
void rot_r(int x){
    int y = t[x].l;
    upd_down(y);
    upd_down(x);
    int g = t[x].p;
    int mid = t[y].r;
    t[x].l = mid;
    t[mid].p = x;           // also touches the null node when mid == 0; harmless
    t[y].p = g;
    if ( g == 0 ) root = y;
    else if ( t[g].l == x ) t[g].l = y;
    else t[g].r = y;
    t[y].r = x;
    t[x].p = y;
    upd_up(x);
    upd_up(y);
}
// Moves x upward until its parent is `to` (to == 0 makes x the new root).
// When x's grandparent is `to`, one rotation suffices. Otherwise:
//  - zig-zig (x and its parent on the same side): rotate grandparent, then parent;
//  - zig-zag (opposite sides): rotate parent, then grandparent.
// Callers in this file locate x with get_kth first, which pushes lazy tags
// down the root-to-x path before any rotation happens.
void splay(int x, int to){
    upd_down(x);
    for ( int par = t[x].p; par != to; par = t[x].p ){
        int grand = t[par].p;
        bool isLeft = ( t[par].l == x );
        if ( grand == to ){
            if ( isLeft ) rot_r(par);
            else rot_l(par);
            continue;
        }
        if ( isLeft ){
            if ( t[grand].l == par ){ rot_r(grand); rot_r(par); }
            else { rot_r(par); rot_l(grand); }
        } else {
            if ( t[grand].r == par ){ rot_l(grand); rot_l(par); }
            else { rot_l(par); rot_r(grand); }
        }
    }
    upd_up(x);
}
// Returns the index of the node at 1-based in-order position k in the subtree
// rooted at x. Pending lazy tags are pushed down along the search path, so the
// following splay only rotates nodes whose tags are already resolved.
// Iterative on purpose: a splay tree can temporarily be very deep, and one
// stack frame per level (as in a recursive descent) could overflow the stack.
// NOTE(review): assumes 1 <= k <= t[x].s; out-of-range k is not checked.
int get_kth(int x, int k){
    for ( ;; ){
        upd_down(x);
        int rank = t[t[x].l].s + 1;   // position of x itself within its subtree
        if ( k == rank ) return x;
        if ( k < rank ){
            x = t[x].l;
        } else {
            k -= rank;
            x = t[x].r;
        }
    }
}
// Returns every node of the subtree rooted at x (0 = empty) to the free stack
// ar2 so new_node can reuse the indices. Node fields are left untouched, and
// new_node reinitialises all of them on reuse.
// Iterative: the entries just pushed onto ar2 double as the work queue, so a
// very deep subtree cannot overflow the call stack.
void remove(int x){
    if ( !x ) return;
    int next = cnt2 + 1;          // first pushed entry whose children are not yet queued
    ar2[++cnt2] = x;
    while ( next <= cnt2 ){
        int cur = ar2[next++];
        if ( t[cur].l ) ar2[++cnt2] = t[cur].l;
        if ( t[cur].r ) ar2[++cnt2] = t[cur].r;
    }
}
void ins(int k, int tot){
    int x = get_kth(root, k+1);
    splay(x, 0);
    int y = get_kth(root, k+2);
    splay(y, x);
    int tt = build(0, tot-1);
    t[KT].l = tt; t[tt].p = KT;
    upd_up(KT); upd_up(root);
}
// Deletes the tot elements starting at 1-based position k and recycles their nodes.
// Tree positions are shifted by one because of the leading dummy node.
void del(int k, int tot){
    int lo = get_kth(root, k);             // node just before the range (left dummy when k == 1)
    splay(lo, 0);
    int hi = get_kth(root, k + tot + 1);   // node just after the range
    splay(hi, lo);
    int host = KT;                         // its left subtree is exactly the range
    remove(t[host].l);
    t[host].l = 0;
    upd_up(host);
    upd_up(root);
}
// Sets every element of the tot-long range starting at 1-based position k to v.
// The range is framed by its two neighbours, which leaves it as the left
// subtree of the root's right child. That subtree is then tagged lazily.
void same(int k, int tot, int v){
    int lo = get_kth(root, k);
    splay(lo, 0);
    int hi = get_kth(root, k + tot + 1);
    splay(hi, lo);
    int seg = t[KT].l;
    upd_same(seg, v);
}
// Reverses the tot-long range starting at 1-based position k.
// The range is framed by its two neighbours, which leaves it as the left
// subtree of the root's right child. That subtree is then tagged lazily.
void rev(int k, int tot){
    int lo = get_kth(root, k);
    splay(lo, 0);
    int hi = get_kth(root, k + tot + 1);
    splay(hi, lo);
    int seg = t[KT].l;
    upd_rev(seg);
}
// Returns the sum of the tot-long range starting at 1-based position k.
// For tot == 0 the isolated subtree is the null node, whose sum init sets to 0.
int get_sum(int k, int tot){
    int lo = get_kth(root, k);
    splay(lo, 0);
    int hi = get_kth(root, k + tot + 1);
    splay(hi, lo);
    int seg = t[KT].l;
    return t[seg].sum;
}
// Returns the best sum of a non-empty contiguous segment of the whole sequence.
// The first and last tree nodes (the dummies) frame everything in between.
// If the sequence is empty, the null node's Max (-INF, as set by init) is returned.
int get_maxsum(){
    int head = get_kth(root, 1);
    splay(head, 0);
    int tail = get_kth(root, t[root].s);
    splay(tail, head);
    int seg = t[KT].l;
    return t[seg].Max;
}
// Resets the tree and builds it from n values read from stdin.
// The sequence is framed by two dummy nodes (one before the first element,
// one after the last), so any range can be isolated as the left subtree of the
// root's right child.
// Every operation only reads or modifies nodes strictly between two splayed
// nodes, so a dummy's value never reaches an answer. The dummies therefore hold
// 0 rather than -INF: with -INF, upd_up on subtrees containing both dummies
// overflowed int (e.g. sum = -2^30 + -2^30 + S < INT_MIN when S < 0), which is UB.
void init(int n){
    cnt = cnt2 = root = 0;
    // Null node 0: empty subtree (size 0, sum 0) whose maxima mean "no segment".
    t[0].v = t[0].l = t[0].r = t[0].p = t[0].s = t[0].sum = 0;
    t[0].Max = t[0].Lmax = t[0].Rmax = -INF;
    t[0].same = t[0].rev = 0;
    for ( int i = 0; i < n; ++i ){
        scanf("%d", &ar[i]);
    }
    const int DUMMY = 0;
    root = new_node(DUMMY);             // leading dummy

    int x = new_node(DUMMY);            // trailing dummy, root's right child
    t[root].r = x; t[x].p = root;
    upd_up(root);

    // The initial values sit between the two dummies.
    int z = build( 0, n-1 );
    t[KT].l = z; t[z].p = KT;
    upd_up(KT); upd_up(root);
}
// Reads test cases until input runs out. Each case starts with n (initial
// length) and m (number of operations), followed by the n initial values
// (read by init) and m commands.
// Commands are dispatched on their first letter. The two commands starting
// with 'M' are told apart by the third character ('X' => max-segment query).
int main()
{
    int n, m;
    // Require both numbers: comparing against EOF alone would loop forever on
    // malformed input where scanf matches nothing but has not reached EOF.
    while ( scanf("%d%d", &n, &m) == 2 ){
        int k, tot;
        char s[20];
        init(n);
        for ( int i = 0; i < m; ++i ){
            // Width-limited so an over-long token cannot overflow s; stop if input ends early.
            if ( scanf("%19s", s) != 1 ) break;
            if ( s[0] == 'G' ){
                scanf("%d%d", &k, &tot);
                printf("%d\n", get_sum(k, tot));
            } else if ( s[0] == 'M' ){
                if ( s[2] == 'X'){
                    printf("%d\n", get_maxsum());
                } else {
                    int v;
                    scanf("%d%d%d", &k, &tot, &v);
                    same(k, tot, v);
                }
            } else if ( s[0] == 'I' ){
                scanf("%d%d", &k, &tot);
                // New values go into the shared buffer ar, which ins builds into a subtree.
                for ( int j = 0; j < tot; ++j ) scanf("%d", ar+j);
                ins(k, tot);
            } else if ( s[0] == 'D' ){
                scanf("%d%d", &k, &tot);
                del(k, tot);
            } else if ( s[0] == 'R' ){
                scanf("%d%d", &k, &tot);
                rev(k, tot);
            }
        }
    }
    return 0;
}

 

posted on 2013-11-24 13:50  leezyli  阅读(215)  评论(0编辑  收藏  举报

导航