BZOJ 1588:splay tree

刚学习的splay tree。照着大神的代码敲了敲,理解了个大概

很好用的数据结构,可以用来维护数列

学习时建议先看看SBT,这样可以更好地理解“旋转”

#include"cstdio"
#include"queue"
#include"cmath"
#include"stack"
#include"iostream"
#include"algorithm"
#include"cstring"
#include"queue"
#include"map"
#include"vector"
#define ll long long
#define mems(a,b) memset(a,b,sizeof(a))

using namespace std;
const int MAXN = 1e5+50;
const int MAXE = 200500;
const int INF = 0x3f3f3f;

int pre[MAXN],val[MAXN],ch[MAXN][2];///父亲结点、值、孩子节点(0左1右)
int tot,root;///总结点数、根节点
///中序遍历为所维护数列

void newnode(int &pos,int fa,int w){
    pos=++tot;
    pre[pos]=fa;
    val[pos]=w;
    ch[pos][0]=ch[pos][1]=0;
}

void Rotate(int x,int kind){///0左1右
    int fa=pre[x];
    ch[fa][!kind]=ch[x][kind];
    pre[ch[x][kind]]=fa;

    if(pre[fa]) ch[pre[fa]][ch[pre[fa]][1]==fa]=x;
    pre[x]=pre[fa];

    ch[x][kind]=fa;
    pre[fa]=x;
}

void Splay(int r,int goal){///将r结点旋转至goal下方
    while(pre[r]!=goal){
        if(pre[pre[r]]==goal) Rotate(r,ch[pre[r]][0]==r);
        else{
            int fa=pre[r];
            int kind=ch[pre[fa]][0]==fa;
            if(ch[fa][kind]==r){///左右交替
                Rotate(r,!kind);
                Rotate(r,kind);
            }
            else{               ///方向一致
                Rotate(fa,kind);
                Rotate(r,kind);
            }
        }
    }
    if(!goal) root=r;///goal为不存在结点时,r为变为根节点
}

bool Insert(int key){
    int r=root;
    while(ch[r][val[r]<key]){
        if(val[r]==key){///相同值只插入一次
            Splay(r,0);
            return false;
        }
        r=ch[r][val[r]<key];
    }
    newnode(ch[r][val[r]<key],r,key);
    Splay(ch[r][val[r]<key],0);
    return true;
}
///BST查询复杂度log(n)
int get_pre(int x){///寻找比X大但最接近X的数
    int t=ch[x][0];
    if(!t) return INF;
    while(ch[t][1]) t=ch[t][1];
    return val[x]-val[t];
}

int get_next(int x){///寻找比X小但最接近X的数
    int t=ch[x][1];
    if(!t) return INF;
    while(ch[t][0]) t=ch[t][0];
    return val[t]-val[x];
}

int main(){
    int n;
    //freopen("in.txt","r",stdin);
    while(~scanf("%d",&n)){
        root=tot=0;
        int ans=0;
        for(int i=1;i<=n;i++){
            int x;
            if(scanf("%d",&x)==EOF) x=0;
            if(i==1){
                ans+=x;
                newnode(root,0,x);
                continue;
            }
            if(!Insert(x)) continue;

            ans+=min(get_pre(root),get_next(root));
        }
        cout<<ans<<endl;
    }
    return 0;
}

 

posted @ 2016-01-13 14:25  Septher  阅读(156)  评论(0编辑  收藏  举报