[lnsyoj118/luoguP3369]普通平衡树

题意

维护一个数据结构,要求支持插入,删除,根据排名查数,根据数查排名,查询前驱,查询后继\(6\)个操作

sol

考虑到后四个查询的操作,会发现使用二叉搜索树(BST)完全可以实现
为了完成这四个操作,需要在每个节点记录\(3\)个值:

  1. \(key\) 表示当前节点的数
  2. \(cnt\) 表示当前节点的数的个数(为了防止出现同一数字出现多次)
  3. \(size\) 表示当前子树的数的个数(为了方便查询排名)

根据排名查数

当处于节点\(u\)时,设当前需要查询的排名为\(rank\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:

  1. 如果\(u.lson.size \ge rank\),说明此时要查询的数一定位于\(u\)的左子树,因此答案为左子树中排名为\(rank\)的数
  2. 如果\(u.lson.size + u.cnt \ge rank\),说明此时要查询的数为\(u.key\),因此答案就为\(u.key\)
  3. 前两条均不满足,则说明此时要查询的数一定位于\(u\)的右子树,又由于需要去除掉左子树和\(u\)的所有数,因此答案为右子树中排名为\(rank - u.lson.size - u.cnt\)的数

代码

int get_key(int u, int rank){
    if (!u) return INF;
    if (rank <= tr[tr[u].l].size) return get_key(tr[u].l, rank);
    if (rank <= tr[tr[u].l].size + tr[u].cnt) return tr[u].key;
    return get_key(tr[u].r, rank - tr[tr[u].l].size - tr[u].cnt);
}

根据数查排名

当处于节点\(u\)时,设当前需要查询的数为\(x\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:

  1. 如果\(x < u.key\),说明此时要查询的数一定位于\(u\)的左子树,因此答案为左子树中数\(x\)的排名
  2. 如果\(x = u.key\),说明此时要查询的数为\(u.key\),因此答案为\(u.lson.size + 1\)
  3. 前两条均不满足,则说明此时要查询的数一定位于\(u\)的右子树,又由于需要加上左子树和\(u\)的所有数,因此答案为右子树中数\(x\)的排名\(+u.lson.size + u.cnt\)

代码

int get_rank(int u, int key){
    if (!u) return 0;
    if (key < tr[u].key) return get_rank(tr[u].l, key);
    if (key == tr[u].key) return tr[tr[u].l].size + 1;
    return tr[tr[u].l].size + tr[u].cnt + get_rank(tr[u].r, key);
}

需要注意的是,部分时候为了方便,我们会在BST中加入两个哨兵节点\(-\infty\)\(+\infty\),此时由于\(-\infty\)的存在,根据排名查数时的\(rank\)需要\(+1\),而根据数查排名时的查得的答案需要\(-1\)

查询前驱

当处于节点\(u\)时,设当前需要查询的数为\(x\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:

  1. \(x \le u.key\),说明此时要查询的数一定不位于\(u\)的右子树,因此答案为左子树中数\(x\)的前驱
  2. \(x > u.key\),说明此时要查询的数可能为\(u.key\),也可能位于\(u\)的右子树,因此答案为右子树中数\(x\)的前驱与\(u.key\)中的最大值

代码

int get_prev(int u, int key){
    if (!u) return -INF;
    if (tr[u].key >= key) return get_prev(tr[u].l, key);
    return max(tr[u].key, get_prev(tr[u].r, key));
}

查询后继

当处于节点\(u\)时,设当前需要查询的数为\(x\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:

  1. \(x \ge u.key\),说明此时要查询的数一定不位于\(u\)的左子树,因此答案为右子树中数\(x\)的前驱
  2. \(x < u.key\),说明此时要查询的数可能为\(u.key\),也可能位于\(u\)的左子树,因此答案为左子树中数\(x\)的前驱与\(u.key\)中的最小值

代码

int get_next(int u, int key){
    if (!u) return INF;
    if (tr[u].key <= key) return get_next(tr[u].r, key);
    return min(tr[u].key, get_next(tr[u].l, key));
}

这样一来,这四个查询操作及两个修改操作的复杂度为\(O(h)\)\(h\)为BST高度。在随机数据下,\(h\)趋向于\(\log n\),但由于BST容易被卡的优秀性质,只需递增/递减数据就可以将BST卡成一条链,从而使\(h=n\),因此,我们需要一些手段来使BST的\(h\)无论何时都接近于\(\log n\),平衡树应运而生

旋转

BST有一条很好的性质:容易被卡中序遍历是单调递增的,反过来也成立,如果我们可以通过一些操作,使中序遍历不变,那么这棵树仍是本质相同的BST,而这个能够使中序遍历不变的操作即为旋转,旋转是几乎所有平衡树都需要使用的操作(部分除外,如FHQ-Treap
image
两图中序遍历都为\(A,Q,B,P,C\)
在执行\(zig\)操作时,需要进行三次改变:\(p.lson \to q.rson(B), q.rson \to p, p \to q\)
同理,在执行\(zag\)操作时,也需要进行三次改变:\(q.rson \to p.lson(B), p.lson \to q, q \to p\)
代码

void zig(int &u){
    int q = tr[u].l;
    tr[u].l = tr[q].r, tr[q].r = u, u = q;
}

void zag(int &u){
    int q = tr[u].r;
    tr[u].r = tr[q].l, tr[q].l = u, u = q;
}

需要注意的是,这里的\(u\)指代的是根节点或某个节点的子节点,当执行\(zig\)\(zag\)时,所对应的节点也要改变,因此需要在函数中传递引用。旋转操作可以视为是BST上三条边所指节点的交换操作

Treap

Treap是OI中较常用的一种平衡树
Treap是Tree和Heap的结合体,它的原理非常简单粗暴:既然BST在随机数据下趋于\(\log n\),那么我们就把所有数据打乱顺序再插入就好了。显然,在\(99.99\%\)的情况之下,这种方法都是有效的。不过因为大多数平衡树解决的问题都是在线问题,因此我们无法简单地将数据打乱。
Treap给出的解决方案是这样的:对于每一个节点,在插入时赋予它一个随机权值\(val\),由于可以通过\(zig\)\(zag\)操作将BST的任一一对父子节点交换而不改变BST的本质,因此我们可以参考二叉堆,插入到对应位置后再向上调整,直到BST中的\(val\)仍然满足二叉堆的性质
对于插入操作,我们先将一个节点插入BST中,然后从下往上判断它是否需要调序;而对于删除操作,我们在BST中找到该节点后,为了方便操作,我们将该节点先调整到叶子结点上,再进行删除。具体代码见下:

void insert(int &u, int key){
    if (!u) u = create(key); // 没有该节点的话,就创建一个新节点
    else if (key == tr[u].key) tr[u].cnt ++ ; // 否则直接在节点上添加标记
    else if (key < tr[u].key){
        insert(tr[u].l, key);
        if (tr[tr[u].l].val < tr[u].val) zig(u); // 向上调序
    }
    else {
        insert(tr[u].r, key);
        if (tr[tr[u].r].val < tr[u].val) zag(u); // 向上调序
    }
}

void erase(int &u, int key){
    if (!u) return ; // 没有该节点的话,无需处理
    else if (key == tr[u].key){
        if (tr[u].cnt > 1) tr[u].cnt -- ; // 如果存在多个标记,直接删除标记
        else if (tr[u].l || tr[u].r){
            if (!tr[u].r || tr[tr[u].l].val > tr[tr[u].r].val){
                zig(u); // 先向下调序
                erase(tr[u].r, key);
            }
            else{
                zag(u); // 先向下调序
                erase(tr[u].l, key);
            }
        }
        else u = 0; // 调到叶子节点后直接删除
    }
    else if (key < tr[u].key) erase(tr[u].l, key);
    else erase(tr[u].r, key);
}

需要注意的是,本题的\(size\)是会在旋转、插入、删除操作中随时改变的,类比线段树,我们还需要一个方法来根据子结点的数据反推节点的\(size\),即PUSHUP
代码:

void pushup(int u){
    tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt;
}

这样的话,我们就通过精巧的操作使BST基本平衡,平均时间复杂度也随之下降为\(O(n \log n)\),不过值得注意的是,其最坏复杂度仍为\(O(n^2)\),只是如果真的卡出来了,概率堪比十连十金

代码

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdlib>

using namespace std;

const int N = 100005, INF = 0x3f3f3f3f;

struct Node{
    int l, r;
    int key, val;
    int cnt, size;
}tr[N];

int root, idx;
int n;

int create(int key){
    tr[ ++ idx].key = key;
    tr[idx].val = rand();
    tr[idx].cnt = tr[idx].size = 1;
    return idx;
}

void pushup(int u){
    tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt;
}

void zig(int &u){
    int q = tr[u].l;
    tr[u].l = tr[q].r, tr[q].r = u, u = q;
    pushup(tr[u].r);
}

void zag(int &u){
    int q = tr[u].r;
    tr[u].r = tr[q].l, tr[q].l = u, u = q;
    pushup(tr[u].l);
}

void build(){
    create(-INF), create(INF);
    root = 1, tr[1].r = 2;
    pushup(root);
}

void insert(int &u, int key){
    if (!u) u = create(key);
    else if (key == tr[u].key) tr[u].cnt ++ ;
    else if (key < tr[u].key){
        insert(tr[u].l, key);
        if (tr[tr[u].l].val < tr[u].val) zig(u);
    }
    else {
        insert(tr[u].r, key);
        if (tr[tr[u].r].val < tr[u].val) zag(u);
    }
    pushup(u);
}

void erase(int &u, int key){
    if (!u) return ;
    else if (key == tr[u].key){
        if (tr[u].cnt > 1) tr[u].cnt -- ;
        else if (tr[u].l || tr[u].r){
            if (!tr[u].r || tr[tr[u].l].val > tr[tr[u].r].val){
                zig(u);
                erase(tr[u].r, key);
            }
            else{
                zag(u);
                erase(tr[u].l, key);
            }
        }
        else u = 0;
    }
    else if (key < tr[u].key) erase(tr[u].l, key);
    else erase(tr[u].r, key);
    
    pushup(u);
}

int get_rank(int u, int key){
    if (!u) return 0;
    if (key < tr[u].key) return get_rank(tr[u].l, key);
    if (key == tr[u].key) return tr[tr[u].l].size + 1;
    return tr[tr[u].l].size + tr[u].cnt + get_rank(tr[u].r, key);
}

int get_key(int u, int rank){
    if (!u) return INF;
    if (rank <= tr[tr[u].l].size) return get_key(tr[u].l, rank);
    if (rank <= tr[tr[u].l].size + tr[u].cnt) return tr[u].key;
    return get_key(tr[u].r, rank - tr[tr[u].l].size - tr[u].cnt);
}

int get_prev(int u, int key){
    if (!u) return -INF;
    if (tr[u].key >= key) return get_prev(tr[u].l, key);
    return max(tr[u].key, get_prev(tr[u].r, key));
}

int get_next(int u, int key){
    if (!u) return INF;
    if (tr[u].key <= key) return get_next(tr[u].r, key);
    return min(tr[u].key, get_next(tr[u].l, key));
}

int main(){
    scanf("%d", &n);
    build();
    while (n -- ){
        int op, x;
        scanf("%d%d", &op, &x);
        switch(op){
            case 1: insert(root, x); break;
            case 2: erase(root, x); break;
            case 3: printf("%d\n", get_rank(root, x) - 1); break;
            case 4: printf("%d\n", get_key(root, x + 1)); break;
            case 5: printf("%d\n", get_prev(root, x)); break;
            case 6: printf("%d\n", get_next(root, x)); break;
            default: break;
        }
    }
    return 0;
}

蒟蒻犯的若至错误

  • \(zig\)\(zag\)的时候没有PUSHUP导致整颗BST的\(size\)都计算错误
posted @ 2024-06-13 21:34  是一只小蒟蒻呀  阅读(23)  评论(0编辑  收藏  举报