[学习笔记] 平衡树-Splay

简介

Splay是一种平衡二叉树。它通过不断地将某个节点旋转到根节点,使整棵树仍然满足二叉查找树的性质,并且保持平衡而不至于退化成链。

Splay的时间复杂度是按总复杂度来算的,具体来说,即是:
从空树开始,做插入、删除、访问操作共M次,树中最多同时存在N个点,
则总时间复杂度不超过\(O(MlogN)\)

通常取平均值,表示为单次均摊\(O(logN)\)

复杂度采用了《算法导论》中的摊还分析,可以看这篇博客的证明:https://blog.csdn.net/qq_31640513/article/details/76944892

属性

  • 二叉查找树的性质

    能够在这棵树上查找某个值的性质:左儿子的值\(<\)根节点的值\(<\)右儿子的值

  • 节点维护信息

    \[\begin{array}{llllll} \hline r t & \text {tot} & f a[i] & \operatorname{ch}[i][0 / 1] & v a l[i] & c n t[i] & s z[i] \\ \hline \end{array} \]

  int rt;//根节点
  int tot;//节点个数
  struct node {
      int fa;//父亲节点
      int ch[2];//子节点
      int val;//权值
      int cnt;//权值出现次数
      int sz;//子树大小
  };

这里为了表示每个节点的属性,采用了结构体的形式

方法

基本方法

  • maintain(x):在改变节点位置后,将节点\(x\)\(size\)更新
  • get(x):判断该节点是左儿子还是右儿子
  • Clear(x):销毁节点\(x\)
    //在改变节点位置后,将节点x的size更新
    inline void maintain(int x) {
        s[x].sz = s[s[x].ch[0]].sz+s[s[x].ch[1]].sz+s[x].cnt;
    }

    //判断该节点是左儿子还是右儿子
    inline bool get(int x) {return x == s[s[x].fa].ch[1];}

    //销毁节点x
    inline void Clear(int x) {
        s[x].ch[0] = s[x].ch[1] = s[x].fa = s[x].val = s[x].sz = s[x].cnt = 0;
    }

旋转方法

必须保证

  • 整棵树的中序遍历不变(不能破坏二叉查找树的性质)
  • 受影响的节点维护的信息依然正确有效。
  • root必须指向旋转后的根节点。

旋转分为两种:左旋右旋

具体步骤分析:

设要旋转的点是\(x\)\(x\)的父亲是\(y\)\(y\)的父亲是\(z\)

分三步:

  1. \(y\)\(x\)的子节点相连:如果\(x\)\(y\)的左儿子,那么\(x\)的右儿子与\(y\)相连
  2. \(x\)\(y\)父子相连
  3. \(x\)\(y\)的原来的父亲 \(z\)相连:如果\(y\)\(z\)的左儿子,那么\(z\)的左儿子与\(x\)相连

Rotate(x)

	inline void Rotate(int x) {
        int y = s[x].fa, z = s[y].fa, chk = get(x);

        //y与x的子节点相连
        s[y].ch[chk] = s[x].ch[chk ^ 1];
        s[s[x].ch[chk ^ 1]].fa = y;

        //x与y父子相连
        s[x].ch[chk ^ 1] = y;
        s[y].fa = x;

        // x与y的原来的父亲z相连
        s[x].fa = z;
        if(z) s[z].ch[y == s[z].ch[1]] = x;

        //只有x和y的sz变化了
        maintain(y);
        maintain(x);
    }

splay方法

每访问一个节点后都要强制将其旋转到根节点

分六种情况:

img

  • 如果\(x\)的父亲是根节点,直接将\(x\)左旋或右旋(图\(1,2\)
  • 如果\(x\)的父亲不是根节点,且\(x\)和它父亲的儿子类型(get(x)==get(f))相同,首先将其父亲左旋或右旋,然后将\(x\)右旋或左旋(图 \(3,4\)
  • 如果\(x\)的父亲不是根节点,且\(x\)和父亲的儿子类型不同,将\(x\)左旋再右旋、或者右旋再左旋(图 \(5,6\)

splay(x):复杂的过程可以转为下面简单的代码

     //将当前节点转移到根节点
    inline void splay(int x) {
        for(int f = s[x].fa; f; Rotate(x),f = s[x].fa){
            if(s[f].fa) Rotate(get(x) == get(f) ? f : x);
        }
        rt=x;
    }

因为对于当前\(x\)Rotate(x)旋转方式只有一种,如果是右儿子就左旋,左儿子就右旋,减少了很多思维上的麻烦,也就不用纠结该左还是右了。

插入方法

插入方法分三种情况

  • 如果树空了则直接插入根并退出
  • 如果原来权值存在,权值个数加一
  • 树中没有这个值,就新建节点

一定要按照二叉查找树的性质遍历树

找到了return前要进行splay()操作,来保证树的平衡

ins(k)

	//插入操作
    inline void ins(int k) {
        //如果树空了则直接插入根并退出
        if(!rt) {
            s[++tot].val = k;
            s[tot].cnt++;
            rt = tot;
            maintain(rt);
            return ;
        }
        int now = rt,f = 0;
        while(true) {
            //如果原来权值存在,权值个数加一
            if(s[now].val == k) {
                s[now].cnt++;
                maintain(now);
                maintain(f);
                splay(now);
                break;
            }
            //按照二叉查找树的性质遍历树
            f = now;
            now = s[now].ch[s[now].val < k];
            //树中没有这个值,就新建节点
            if(!now) {
                s[++tot].val = k;
                s[tot].cnt++;
                s[tot].fa = f;
                s[f].ch[s[f].val < k] = tot;
                maintain(tot);
                maintain(f);
                splay(tot);
                break;
            }
        }
    }

查询x的排名

还是按照查找二叉树的性质进行查找

  • 如果\(x\)比当前节点的权值小,向其左子树查找。
  • 如果\(x\)比当前节点的权值大,将答案加上左子树$size \(和当前节点\)cnt$的大小,向其右子树查找。
  • 如果 与当前节点的权值相同,将答案加\(1\)并返回。

Find(k)

    //查找某个数 返回这个数是第几个
    inline int Find(int k) {
        int res = 0,now = rt;
        while(true) {
            //如果这个数比当前节点小,搜索左子树
            if(k<s[now].val) {
                now = s[now].ch[0];
            }else {
                //否则加上右子树的个数
                res += s[s[now].ch[0]].sz;
                //中序遍历,如果找到这个节点返回res+1
                if(k == s[now].val) {
                    splay(now);
                    return res + 1;
                }
                res += s[now].cnt;
                now = s[now].ch[1];
            }
        }
    }

查询排名x的数

  • 如果左子树非空且剩余排名\(k\)不大于左子树的大小 $ size $,那么向左子树查找。
  • 否则将\(k\)减去左子树的和根的大小。如果此时\(k\)的值小于等于\(size\),则返回根节点的权值,否则继续向右子树查找。

getKth(k)

    //查询第k个数
    inline int getKth(int k) {
        int now = rt;
        while(true){
            if(s[now].ch[0] && k <= s[s[now].ch[0]].sz){
                now = s[now].ch[0];
            }else{
                k -= s[now].cnt + s[s[now].ch[0]].sz;
                if(k <= 0){
                    splay(now);
                    return s[now].val;
                }
                now=s[now].ch[1];
            }
        }
    }

查询前驱和后继

getPre():查询小于x的最大的数的节点,就是找左儿子的右链

getNxt():查询大于x的最小的数的节点,就是找右儿子的左链

    //查询小于x的最大的数的节点,就是找左儿子的右链
    inline int getPre() {
        int now = s[rt].ch[0];
        while (s[now].ch[1]) now = s[now].ch[1];
        return now;
    }

    //查询大于x的最小的数的节点,同理
    inline int getNxt() {
        int now = s[rt].ch[1];
        while (s[now].ch[0]) now = s[now].ch[0];
        return now;
    }

删除方法

删除方法具体步骤:

  1. 首先将\(x\)旋转到根的位置,要用到Find(x)先找到\(x\)
  2. 如果大于\(1\),则不需要删除节点,只需要将\(cnt-1\)
  3. 如果只有一个点,删除这个点之后,将rt变为\(0\)
  4. 如果左右一个儿子,就将该点删除,并让那一个儿子成为根节点
  5. 否则就将\(x\)的前驱旋转到根节点,并将\(x\)的右儿子与根节点相连,将\(x\)删除。

del(x)

    inline void del(int k){
        Find(k);//先让该点成为根节点
        if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
            s[rt].cnt--;
            maintain(rt);
            return;
        }
        //如果只有一个点
        if(!s[rt].ch[0] && !s[rt].ch[1]){
            Clear(rt);
            rt = 0;
            return;
        }
        //没有左儿子,让右儿子成为根节点
        if(!s[rt].ch[0]){
            int tmp = rt;
            rt = s[rt].ch[1];
            s[rt].fa=0;
            Clear(tmp);
            return;
        }
        //没有右儿子,让左儿子成为根节点
        if(!s[rt].ch[1]){
            int tmp = rt;
            rt = s[rt].ch[0];
            s[rt].fa = 0;
            Clear(tmp);
            return;
        }
        //有左右儿子,让前驱成为根节点
        int x = getPre() , now = rt;
        splay(x);
        s[s[now].ch[1]].fa = x;
        s[x].ch[1] = s[now].ch[1];
        Clear(now);
        maintain(rt);
    }

模板题

https://www.luogu.com.cn/problem/P3369

完整代码:

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+7;

int rt;//根节点
int tot;//节点个数
struct node {
    int fa;//父亲节点
    int ch[2];//子节点
    int val;//权值
    int cnt;//权值出现次数
    int sz;//子树大小
}s[N];

struct Splay {

    //在改变节点位置后,将节点x的size更新
    inline void maintain(int x) {
        s[x].sz = s[s[x].ch[0]].sz+s[s[x].ch[1]].sz+s[x].cnt;
    }

    //判断该节点是左儿子还是右儿子
    inline bool get(int x) {return x == s[s[x].fa].ch[1];}

    //销毁节点x
    inline void Clear(int x) {
        s[x].ch[0] = s[x].ch[1] = s[x].fa = s[x].val = s[x].sz = s[x].cnt = 0;
    }

    inline void Rotate(int x) {
        int y = s[x].fa, z = s[y].fa, chk = get(x);

        //y与x的子节点相连
        s[y].ch[chk] = s[x].ch[chk ^ 1];
        s[s[x].ch[chk ^ 1]].fa = y;

        //x与y父子相连
        s[x].ch[chk ^ 1] = y;
        s[y].fa = x;

        // x与y的原来的父亲z相连
        s[x].fa = z;
        if(z) s[z].ch[y == s[z].ch[1]] = x;

        //只有x和y的sz变化了
        maintain(y);
        maintain(x);
    }
    //将当前节点转移到根节点
    inline void splay(int x) {

        for(int f = s[x].fa; f; Rotate(x),f = s[x].fa){
            if(s[f].fa) Rotate(get(x) == get(f) ? f : x);
        }
        rt=x;
    }
    //插入操作
    inline void ins(int k) {
        //如果树空了则直接插入根并退出
        if(!rt) {
            s[++tot].val = k;
            s[tot].cnt++;
            rt = tot;
            maintain(rt);
            return ;
        }
        int now = rt,f = 0;
        while(true) {
            //如果原来权值存在,权值个数加一
            if(s[now].val == k) {
                s[now].cnt++;
                maintain(now);
                maintain(f);
                splay(now);
                break;
            }
            //按照二叉查找树的性质遍历树
            f = now;
            now = s[now].ch[s[now].val < k];
            //树中没有这个值,就新建节点
            if(!now) {
                s[++tot].val = k;
                s[tot].cnt++;
                s[tot].fa = f;
                s[f].ch[s[f].val < k] = tot;
                maintain(tot);
                maintain(f);
                splay(tot);
                break;
            }
        }
    }
    //查找某个数 返回这个数是第几个
    inline int Find(int k) {
        int res = 0,now = rt;
        while(true) {
            //如果这个数比当前节点小,搜索左子树
            if(k<s[now].val) {
                now = s[now].ch[0];
            }else {
                //否则加上右子树的个数
                res += s[s[now].ch[0]].sz;
                //中序遍历,如果找到这个节点返回res+1
                if(k == s[now].val) {
                    splay(now);
                    return res + 1;
                }
                res += s[now].cnt;
                now = s[now].ch[1];
            }
        }
    }

    //查询小于x的最大的数的节点,就是找左儿子的右链
    inline int getPre() {
        int now = s[rt].ch[0];
        while (s[now].ch[1]) now = s[now].ch[1];
        return now;
    }

    //查询大于x的最小的数的节点,同理
    inline int getNxt() {
        int now = s[rt].ch[1];
        while (s[now].ch[0]) now = s[now].ch[0];
        return now;
    }

    //查询第k个数
    inline int getKth(int k) {
        int now = rt;
        while(true){
            if(s[now].ch[0] && k <= s[s[now].ch[0]].sz){
                now = s[now].ch[0];
            }else{
                k -= s[now].cnt + s[s[now].ch[0]].sz;
                if(k <= 0){
                    splay(now);
                    return s[now].val;
                }
                now=s[now].ch[1];
            }
        }
    }

    //删除结点
    inline void del(int k){
        Find(k);//先让该点成为根节点
        if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
            s[rt].cnt--;
            maintain(rt);
            return;
        }
        //如果只有一个点
        if(!s[rt].ch[0] && !s[rt].ch[1]){
            Clear(rt);
            rt = 0;
            return;
        }
        //没有左儿子,让右儿子成为根节点
        if(!s[rt].ch[0]){
            int tmp = rt;
            rt = s[rt].ch[1];
            s[rt].fa=0;
            Clear(tmp);
            return;
        }
        //没有右儿子,让左儿子成为根节点
        if(!s[rt].ch[1]){
            int tmp = rt;
            rt = s[rt].ch[0];
            s[rt].fa = 0;
            Clear(tmp);
            return;
        }
        //有左右儿子,让前驱成为根节点
        int x = getPre() , now = rt;
        splay(x);
        s[s[now].ch[1]].fa = x;
        s[x].ch[1] = s[now].ch[1];
        Clear(now);
        maintain(rt);
    }
}st;
int main(){
    int n,opt,x;
    scanf("%d",&n);
    while(n--){
        scanf("%d%d",&opt,&x);
        if(opt == 1) st.ins(x);
        else if(opt == 2) st.del(x);
        else if(opt == 3) printf("%d\n",st.Find(x));
        else if(opt == 4) printf("%d\n",st.getKth(x));
        else if(opt == 5) {
            st.ins(x);
            printf("%d\n",s[st.getPre()].val);
            st.del(x);
        }
        else {
            st.ins(x);
            printf("%d\n",s[st.getNxt()].val);
            st.del(x);
        }
    }
    return 0;
}

oiwiki上的代码没有用结构体写起来比较快:

#include <cstdio>
const int N = 100005;
int rt, tot, fa[N], ch[N][2], val[N], cnt[N], sz[N];
struct Splay {
  void maintain(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }
  bool get(int x) { return x == ch[fa[x]][1]; }
  void clear(int x) {
    ch[x][0] = ch[x][1] = fa[x] = val[x] = sz[x] = cnt[x] = 0;
  }
  void rotate(int x) {
    int y = fa[x], z = fa[y], chk = get(x);
    ch[y][chk] = ch[x][chk ^ 1];
    fa[ch[x][chk ^ 1]] = y;
    ch[x][chk ^ 1] = y;
    fa[y] = x;
    fa[x] = z;
    if (z) ch[z][y == ch[z][1]] = x;
    maintain(x);
    maintain(y);
  }
  void splay(int x) {
    for (int f = fa[x]; f = fa[x], f; rotate(x))
      if (fa[f]) rotate(get(x) == get(f) ? f : x);
    rt = x;
  }
  void ins(int k) {
    if (!rt) {
      val[++tot] = k;
      cnt[tot]++;
      rt = tot;
      maintain(rt);
      return;
    }
    int cnr = rt, f = 0;
    while (1) {
      if (val[cnr] == k) {
        cnt[cnr]++;
        maintain(cnr);
        maintain(f);
        splay(cnr);
        break;
      }
      f = cnr;
      cnr = ch[cnr][val[cnr] < k];
      if (!cnr) {
        val[++tot] = k;
        cnt[tot]++;
        fa[tot] = f;
        ch[f][val[f] < k] = tot;
        maintain(tot);
        maintain(f);
        splay(tot);
        break;
      }
    }
  }
  int rk(int k) {
    int res = 0, cnr = rt;
    while (1) {
      if (k < val[cnr]) {
        cnr = ch[cnr][0];
      } else {
        res += sz[ch[cnr][0]];
        if (k == val[cnr]) {
          splay(cnr);
          return res + 1;
        }
        res += cnt[cnr];
        cnr = ch[cnr][1];
      }
    }
  }
  int kth(int k) {
    int cnr = rt;
    while (1) {
      if (ch[cnr][0] && k <= sz[ch[cnr][0]]) {
        cnr = ch[cnr][0];
      } else {
        k -= cnt[cnr] + sz[ch[cnr][0]];
        if (k <= 0) return val[cnr];
        cnr = ch[cnr][1];
      }
    }
  }
  int pre() {
    int cnr = ch[rt][0];
    while (ch[cnr][1]) cnr = ch[cnr][1];
    return cnr;
  }
  int nxt() {
    int cnr = ch[rt][1];
    while (ch[cnr][0]) cnr = ch[cnr][0];
    return cnr;
  }
  void del(int k) {
    rk(k);
    if (cnt[rt] > 1) {
      cnt[rt]--;
      maintain(rt);
      return;
    }
    if (!ch[rt][0] && !ch[rt][1]) {
      clear(rt);
      rt = 0;
      return;
    }
    if (!ch[rt][0]) {
      int cnr = rt;
      rt = ch[rt][1];
      fa[rt] = 0;
      clear(cnr);
      return;
    }
    if (!ch[rt][1]) {
      int cnr = rt;
      rt = ch[rt][0];
      fa[rt] = 0;
      clear(cnr);
      return;
    }
    int x = pre(), cnr = rt;
    splay(x);
    fa[ch[cnr][1]] = x;
    ch[x][1] = ch[cnr][1];
    clear(cnr);
    maintain(rt);
  }
} tree;

int main() {
  int n, opt, x;
  for (scanf("%d", &n); n; --n) {
    scanf("%d%d", &opt, &x);
    if (opt == 1)
      tree.ins(x);
    else if (opt == 2)
      tree.del(x);
    else if (opt == 3)
      printf("%d\n", tree.rk(x));
    else if (opt == 4)
      printf("%d\n", tree.kth(x));
    else if (opt == 5)
      tree.ins(x), printf("%d\n", val[tree.pre()]), tree.del(x);
    else
      tree.ins(x), printf("%d\n", val[tree.nxt()]), tree.del(x);
  }
  return 0;
}

来源

[1] https://oi-wiki.org/ds/splay/

posted @ 2020-03-04 17:50  house_cat  阅读(250)  评论(0编辑  收藏  举报