Splay - POJ 3481 - Double Queue

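Splay 模板题：实现插入、删除以及求前驱、后继的函数。在树中加入左右两个哨兵（-INF 和 INF），可以更方便地查询最大值和最小值。
(A standard splay-tree exercise: implement insertion, deletion, and predecessor/successor queries. Adding left and right sentinel nodes (-INF and INF) makes finding the minimum and maximum straightforward.)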

#include <bits/stdc++.h>
using namespace std;

// Node-pool capacity. Node indices are never reused (idx only grows), so this
// bounds the total number of insertions, including the two sentinels.
const int N = 1000000 + 5;
// Sentinel key. Assumes every real priority lies strictly inside (-INF, INF)
// -- TODO confirm against the problem's input bounds.
const int INF = 0x3fffffff;

// One splay-tree node stored in the array pool `tr`.
// Index 0 is the null node; real nodes are allocated from index 1 upward.
struct Node{
    int s[2],p,v;   // children (s[0] left, s[1] right), parent, ordering key
    int id;         // payload: the client id printed when the node is served
    int size;       // number of nodes in this subtree, maintained by push_up
    // Initialise a freshly allocated node as a leaf with key _v, payload _id
    // and parent _p. Children and size are set explicitly so the node is
    // self-consistent even before any push_up recomputes it (previously a
    // node that was never rotated kept size 0).
    void init(int _v,int _id,int _p){
        v = _v;
        id = _id;
        p = _p;
        s[0] = s[1] = 0;
        size = 1;
    }
}tr[N];

int root,idx;   // current root (0 = empty tree) and last allocated node index

// Which child x is of its parent: 1 for the right child, 0 otherwise.
inline int ws(int x){
    const int parent = tr[x].p;
    return tr[parent].s[1] == x ? 1 : 0;
}

inline void push_up(int x){
    tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}

// Single rotation: lift x one level above its parent.
// The inner subtree of x (the child on the side facing the parent) is handed
// over to the parent, and the parent becomes x's child on that side.
// Note: when the parent is the root, or the inner subtree is empty, this
// writes into the null node 0's child/parent fields.
inline void rotate(int x){
    const int parent = tr[x].p;
    const int grand = tr[parent].p;
    const int side = ws(x);             // side of parent that x hangs on
    const int parentSide = ws(parent);  // side of grand that parent hangs on
    const int inner = tr[x].s[side ^ 1];

    // grand <-> x
    tr[grand].s[parentSide] = x;
    tr[x].p = grand;

    // parent <-> inner subtree of x
    tr[parent].s[side] = inner;
    tr[inner].p = parent;

    // x <-> parent
    tr[x].s[side ^ 1] = parent;
    tr[parent].p = x;

    // parent is now below x, so refresh it first
    push_up(parent);
    push_up(x);
}

// Splay x upward until its parent is k (k == 0 makes x the new root).
// Two levels per step: rotate x first for a zig-zag shape, or its parent
// first for a zig-zig shape; a lone final step is a single rotation of x.
inline void splay(int x,int k){
    for (int y = tr[x].p; y != k; y = tr[x].p) {
        const int z = tr[y].p;
        if (z != k) rotate(ws(x) != ws(y) ? x : y);
        rotate(x);
    }
    if (k == 0) root = x;
}

// Insert a new node with key v and payload id, splay it to the root, and
// return its index. Keys equal to an existing key are placed to its left.
inline int insert(int v,int id){
    int parent = 0;
    // Descend from the root; the comparison picks the child to follow.
    for (int cur = root; cur; cur = tr[cur].s[v > tr[cur].v]) parent = cur;

    const int fresh = ++idx;
    if (parent) tr[parent].s[v > tr[parent].v] = fresh;
    tr[fresh].init(v, id, parent);
    splay(fresh, 0);
    return fresh;
}

// Return the in-order predecessor of node x, or 0 if x has none.
// Side effect: x is splayed to the root first, so its predecessor is the
// rightmost node of its left subtree.
int get_pre(int x){
    splay(x, 0);
    int u = tr[x].s[0];
    // rotate() may write child pointers into the null node 0 (when it rotates
    // at the root), so an empty left subtree must be reported directly
    // instead of walking tr[0].s[1], which can hold a stale node index.
    if (!u) return 0;
    while (tr[u].s[1]) u = tr[u].s[1];
    return u;
}

// Return the in-order successor of node x, or 0 if x has none.
// Side effect: x is splayed to the root first, so its successor is the
// leftmost node of its right subtree.
int get_suc(int x){
    splay(x, 0);
    int u = tr[x].s[1];
    // rotate() may write child pointers into the null node 0 (when it rotates
    // at the root), so an empty right subtree must be reported directly
    // instead of walking tr[0].s[0], which can hold a stale node index.
    if (!u) return 0;
    while (tr[u].s[0]) u = tr[u].s[0];
    return u;
}

void remove(int x){
    int l = get_pre(x);
    int r = get_suc(x);
    splay(l, 0);
    splay(r, l);
    tr[r].s[0] = 0;
    push_up(r);
    push_up(l);
    
}

int op,id,v;

// Request loop. Reads until end of input, a malformed read, or a request
// code other than 1/2/3 (e.g. 0):
//   1 K P : add client K with priority P
//   2     : print and remove the client with the highest priority
//   3     : print and remove the client with the lowest priority
// Serving from an empty queue prints 0.
int main(){
    // Sentinel nodes bracketing every real key, so each real node always has
    // both a predecessor and a successor (remove() relies on this).
    const int L = insert(-INF, 0);
    const int R = insert(INF, 0);

    // Compare with 1, not EOF: on malformed input scanf returns 0 without
    // consuming anything, which would otherwise spin forever.
    while (scanf("%d", &op) == 1) {
        if (op == 1) {
            if (scanf("%d%d", &id, &v) != 2) break;  // truncated request
            insert(v, id);
        } else if (op == 2 || op == 3) {
            // Only the two sentinels remain: the queue is empty.
            if (tr[root].size == 2) {
                puts("0");
                continue;
            }
            // Highest key = predecessor of R; lowest key = successor of L.
            const int x = (op == 2) ? get_pre(R) : get_suc(L);
            printf("%d\n", tr[x].id);
            remove(x);
        } else {
            break;  // 0 (or any unknown code) terminates the input
        }
    }

    return 0;
}

posted @ 2021-02-20 16:21  popozyl  阅读(43)  评论(0)  编辑  收藏  举报