【平衡树一】Splay树

Splay树

节点维护信息

rt tot fa[i] chi val[i] cnt[i] size[i]
根节点编号 节点个数 父亲 左右儿子编号 节点权值 权值出现次数 子树大小

操作

基本操作

maintain(x):在改变节点位置后,将节点x的size更新;

get(x):判断节点x是父亲节点的左孩子还是右孩子;

clear(x):销毁节点x;

void maintain(int x) {size[x] = size[ch[x]][0]] + size[ch[x][1]] + cnt[x];}  
bool get(int x) {return x == ch[fa[x]][1];}

void clear(int x) 
{
    ch[x][0] = ch[x][1] = fa[x] = val[x] = size[x] = cnt[x] = 0;
}

旋转操作

为了使得Splay保持平衡而进行旋转操作,旋转的本质就是将某个节点上移一个位置。
旋转需要保证:
1.整颗树Splay中序遍历不变(不破坏二叉查找树的性质);
2.受影响节点维护信息依然正确有效;
3.root必须指向旋转后的根节点;

Splay中分两种旋转:左旋和右旋

                1                       2     
               / \                     / \     
              2   3    --->右旋        4  1     
             / \       <---左旋          / \     
            4   5                       5  3       

具体分析旋转步骤:
假设需要旋转的节点为x,其父亲为y,以右旋为例;
1.将y的左孩子指向x的右孩子,且x的右孩子的父亲指向y;
ch[y][0] = ch[x][1];
fa[ch[x][1]] = y;
2.将x的右孩子指向y,且y的父亲指向x;

    ch[x][chk^1] = y;
    fa[y] = x;

3.如果原来y还有父亲z,那么把z的某个儿子(原来y所在的位置)指向x,且x的父亲指向z;

    fa[x] = z;
    if (z)  {
        ch[z][y == ch[z][1]] = x;
    }
void rotate(int x) 
{
    int y = fa[x], z = fa[y], chk = get(x);
    //step1
    ch[y][chk] = ch[x][chk^1];
    fa[ch[x][chk^1]] = y;
    //step2
    ch[x][chk^1] = y;
    fa[y] = x;
    //step3
    fa[x] = z;
    if (z) ch[z][y == ch[z][1]] = x;
    maintain(y);
    maintain(x);
}

Splay操作

splay规定:每访问一个节点后都要强制将其旋转到根节点,此时旋转操作具体分6种情况讨论(其中x为需要旋转到根的节点)。

  1.    y       
      /          
     x          
    
  2.    y      
        \    
         x    
    
  3.           z     
             /     
            y    
           /     
          x          
    
  4.      z       
          \    
           y   
            \    
             x    
    
  5.      z       
          \    
           y   
          /     
         x
    
  6.           z     
             /     
            y    
             \    
              x         
    

1.如果x的父亲是根节点,直接旋转x左旋或右旋;(1/2)
2.如果x的父亲不是根节点,且x和父亲的儿子类型相同,首先将其父亲左旋或右旋,然后再将x旋转;(3/4)
3.如果x的父亲不是根节点,且x和父亲的儿子类型不相同,将x左旋在右旋,或者右旋在左旋;(5/6)

void splay(int x)
{
    for (int f = fa[x]; f = fa[x], f; rotate(x)) {
        if (fa[f]) rotate(get(x) == get(f) ? f : x);
    }
    rt = x;
}

插入操作

插入操作是一个比较复杂的过程,具体步骤如下(插入值为k):
1.如果树空,直接插入退出;
2.如果当前节点权值等于k则增加当前节点的大小并更新节点和父亲的信息,将当前节点进行Splay操作;
3.否则按照二叉树性质往下找,找到空节点插入即可,最后Splay一下;

void Insert(int k)
{
    if (!rt)  {
        val[++tot] = k;
        cnt[tot]++;
        rt = tot;
        maintain(rt);
        return;
    }

    int cnr = rt, f = 0;
    while (1) {
        if (val[cnr] == k) {
            cnt[cnr]++;
            maintain(cnr);
            maintain(f);
            splay(cnr);
            break;
        }

        f = cnr;
        cnr = ch[cnr][val[cnr] < k];
        if (!cnr) {
            val[++tot] = k;
            cnt[tot]++;
            fa[tot] = f;
            ch[f][val[f] < k] = tot;
            maintain(tot);
            maintain(f);
            splay(tot);
            break;
        }
    }
}

查询x的排名

根据二叉树的定义和性质,显然可以按照以下步骤查询x的排名:
1.如果x比当前节点的权值小,向其左子树查找;
2.如果x比当前节点的权值大,将其加上左子树和当前节点cnt大小,向右子树查找;
3.如果x与当前节点的权值相同,将答案加1并返回;

int rank(int k)
{
    int res = 0, cnr = rt;
    while (1) {
        if (k < val[cnr]) {
            cnr = ch[cnr][0];
        }else {
            res += size[ch[cnr][0]];
            if (k == val[cnr]) {
                splay(cnr);
                return res+1;
            }

            res += cnt[cnr];
            cnr = ch[cnr][1];
        }
    }
}

查询排名x的数

设k为剩余排名,步骤如下:
1.如果左子树非空且剩余排名k不大于左子树的size,那么向左子树查找;
2.否则将k减去左子树和根的大小,如果此时k的值小于等于0,则返回根节点的权值,否则继续向右子树查找;

inr kth(int k)
{
    int cnr = rt;
    while(1) {
        if (ch[cnr][0] && size[ch[cnr][0]] >= k) {
            cnr = ch[cnr][0];
        }else {
            k -= cnt[cnr] + size(ch[cnr][0]);
            if (k <= 0) {
                splay(cnr);
                return val[cnr];
            }
            cnr = ch[cnr][1];
        }

    }
}

查询前驱

前驱定义为小于x的最大数,那么查询前驱就可以变为:将x插入(此时x已经在根节点上),前驱即为x的左子树中最右边的节点,最后将x删除即可;

int pre(int x) {
    Insert(x);
    int cnr = ch[rt][0];
    while (ch[cnr][1]) cnr = ch[cnr][1];
    splay(cnr);

    delete(x);
    return cnr;
}

查询后继

后继定义为大于x的最小的数,查询和前驱类似:x的右子树中的最左边节点;

int next()
{
    int cnr = ch[rt][1];
    while (ch[cnr][0]) cnr = ch[cnr][0];
    splay(cnr);
    return cnr;
}

合并两棵树

合并两颗Splay树,设两颗树的根节点分别为x和y,先做着假定,要求x树种最大值不大于y中的最小值。合并操作如下:
1.如果x和y其中之一或两者都为空树,直接返回不为空的那一颗树的根节点或空树;
2.否则将x树中的最大值Splay到根节点,然后将他的右子树设置为y并更新节点信息,然后返回这个节点;

int merge(int x, int y)
{
    if (!x) {
        return y;
    }

    if (!y) {
        return x;
    }

    //找到x树种最大值;
    int cnr = ch[x][1];
    while (ch[cnr][1]) cnr = ch[cnr][1];
    splay(cnr);

    fa[y] = cnr;
    ch[cnr][1] = y;
    return cnr;
}

删除操作

删除操作具体如下:
1.首先将x旋转到根节点;
2.如果cnt[x] > 1,将cnt[x]--后退出;
3.否则,合并它的左右两颗子树即可;

void del(int k)
{
    rank(k);    //查询旋转到根节点;

    if (cnt[k] > 1) {
        cnt[k]--;
        maintain(rt);
        return;
    }

    //cnt[k]==1;
    if (!ch[rt][0] && !ch[rt][1]) {
        clear(rt);
        rt = 0;
        return ;
    }

    if (!ch[rt][0]) {
        int cnr = rt;
        rt = ch[rt][1];
        fa[rt] = 0;
        clear(cnr);
        return ;
    }

    if (!ch[rt][1]) {
        int cnr = rt;
        rt = ch[rt][0];
        fa[rt] = 0;
        clear(cnt);
        return;
    }

    int cnr = rt, x = pre();
    splay(x);
    fa[ch[cnr][1]] = x;
    ch[x][1] = ch[cnr][1];
    clear(cnr);
    maintain(cnr);
    return;
}

完整代码

#include <cstdio>
const int N = 100005;
int rt, tot, fa[N], ch[N][2], val[N], cnt[N], sz[N];
struct Splay {
  void maintain(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }
  bool get(int x) { return x == ch[fa[x]][1]; }
  void clear(int x) {
    ch[x][0] = ch[x][1] = fa[x] = val[x] = sz[x] = cnt[x] = 0;
  }
  void rotate(int x) {
    int y = fa[x], z = fa[y], chk = get(x);
    ch[y][chk] = ch[x][chk ^ 1];
    fa[ch[x][chk ^ 1]] = y;
    ch[x][chk ^ 1] = y;
    fa[y] = x;
    fa[x] = z;
    if (z) ch[z][y == ch[z][1]] = x;
    maintain(x);
    maintain(y);
  }
  void splay(int x) {
    for (int f = fa[x]; f = fa[x], f; rotate(x))
      if (fa[f]) rotate(get(x) == get(f) ? f : x);
    rt = x;
  }
  void ins(int k) {
    if (!rt) {
      val[++tot] = k;
      cnt[tot]++;
      rt = tot;
      maintain(rt);
      return;
    }
    int cnr = rt, f = 0;
    while (1) {
      if (val[cnr] == k) {
        cnt[cnr]++;
        maintain(cnr);
        maintain(f);
        splay(cnr);
        break;
      }
      f = cnr;
      cnr = ch[cnr][val[cnr] < k];
      if (!cnr) {
        val[++tot] = k;
        cnt[tot]++;
        fa[tot] = f;
        ch[f][val[f] < k] = tot;
        maintain(tot);
        maintain(f);
        splay(tot);
        break;
      }
    }
  }
  int rk(int k) {
    int res = 0, cnr = rt;
    while (1) {
      if (k < val[cnr]) {
        cnr = ch[cnr][0];
      } else {
        res += sz[ch[cnr][0]];
        if (k == val[cnr]) {
          splay(cnr);
          return res + 1;
        }
        res += cnt[cnr];
        cnr = ch[cnr][1];
      }
    }
  }
  int kth(int k) {
    int cnr = rt;
    while (1) {
      if (ch[cnr][0] && k <= sz[ch[cnr][0]]) {
        cnr = ch[cnr][0];
      } else {
        k -= cnt[cnr] + sz[ch[cnr][0]];
        if (k <= 0) {
          splay(cnr);
          return val[cnr];
        }
        cnr = ch[cnr][1];
      }
    }
  }
  int pre() {
    int cnr = ch[rt][0];
    while (ch[cnr][1]) cnr = ch[cnr][1];
    splay(cnr);
    return cnr;
  }
  int nxt() {
    int cnr = ch[rt][1];
    while (ch[cnr][0]) cnr = ch[cnr][0];
    splay(cnr);
    return cnr;
  }
  void del(int k) {
    rk(k);
    if (cnt[rt] > 1) {
      cnt[rt]--;
      maintain(rt);
      return;
    }
    if (!ch[rt][0] && !ch[rt][1]) {
      clear(rt);
      rt = 0;
      return;
    }
    if (!ch[rt][0]) {
      int cnr = rt;
      rt = ch[rt][1];
      fa[rt] = 0;
      clear(cnr);
      return;
    }
    if (!ch[rt][1]) {
      int cnr = rt;
      rt = ch[rt][0];
      fa[rt] = 0;
      clear(cnr);
      return;
    }
    int cnr = rt;
    int x = pre();
    splay(x);
    fa[ch[cnr][1]] = x;
    ch[x][1] = ch[cnr][1];
    clear(cnr);
    maintain(rt);
  }
} tree;

int main() {
  int n, opt, x;
  for (scanf("%d", &n); n; --n) {
    scanf("%d%d", &opt, &x);
    if (opt == 1)
      tree.ins(x);
    else if (opt == 2)
      tree.del(x);
    else if (opt == 3)
      printf("%d\n", tree.rk(x));
    else if (opt == 4)
      printf("%d\n", tree.kth(x));
    else if (opt == 5)
      tree.ins(x), printf("%d\n", val[tree.pre()]), tree.del(x);
    else
      tree.ins(x), printf("%d\n", val[tree.nxt()]), tree.del(x);
  }
  return 0;
}

//参考网址:https://oi-wiki.org/ds/splay/#_8
为了加深理解,手抄一遍;

posted @   zhanghanLeo  阅读(183)  评论(0编辑  收藏  举报
编辑推荐:
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示