二叉平衡树

您需要写一个数据结构来维护一些数,其中需要提供以下操作:

  1. 插入数值 \(x\).

  2. 删除数值 \(x\) (其中若有多个相同的数,则应只删除一个).

  3. 查询数值 \(x\) 的排名(其中若有多个相同的数,则应输出最小的排名).

  4. 查询排名为 \(x\) 的数值.

  5. 求数值 \(x\) 的前驱(前驱定义为小于 \(x\) 的最大的数).

  6. 求数值 \(x\) 的后继(后继定义为大于 \(x\) 的最小的数).

其中共有 \(n\ \left(1 \le n \le 100000\right)\) 次操作,所有数均在 \(-10^7\)\(10^7\) 以内。对于每个操作 \(3,4,5,6\) 输出一行,包括一个数,即询问的答案。

注意 : 数据保证查询的结果一定存在。

Solution

我们使用一种名为 Treap 的平衡树。Treap = BST + heap。对于 heap,没什么好说的,我们来说明一下 BSTBST (Binary Search Tree) 就是一个 中序遍历单调递增 的树。

结构体介绍

结构体内存储六个变量,\(l, r, w, value, cnt, size\),分别表示左子节点,右子节点,点权,排序优先级,该点权出现次数,子树大小。

struct node
{
    int l, r, w, value;
    int cnt, sizes;
}tree[N];

基础操作

  • \(\operatorname{void\ pushup(int\ w)}\) 上传节点信息
void pushup(int p)
{
    tree[p].size = tree[tree[p].l].size + tree[tree[p].r].size;
    tree[p].size += tree[p].cnt;
}
  • \(\operatorname{void\ get\_node(int\ w)}\) 新建点权为 \(w\) 的节点
int get_node(int w)
{
    tree[++ idx].w = w;
    tree[idx].value = rand();
    tree[idx].cnt = tree[idx].size = 1;
    return idx;
}
  • \(\operatorname{void\ build()}\) 建树
void build()
{
    get_node(-inf), get_node(inf);
    root = 1, tree[1].r = 2;
    pushup(root);
    if (tree[1].value < tree[2].value) zag(root);
}
  • \(\operatorname{void\ zig(int\ \&p)}\) 右旋
void zig(int &p)    // 右旋
{
    int q = tree[p].l;
    tree[p].l = tree[q].r, tree[q].r = p, p = q;
    pushup(tree[p].r), pushup(p);
}
  • \(\operatorname{void\ zag(int\ \&p)}\) 左旋
void zag(int &p)    // 左旋
{
    int q = tree[p].r;
    tree[p].r = tree[q].l, tree[q].l = p, p = q;
    pushup(tree[p].l), pushup(p);
}

插入

void insert(int &p, int w)
{
    if (!p) p = get_node(w);
    else if (tree[p].w == w) tree[p].cnt ++;
    else if (tree[p].w > w)
    {
        insert(tree[p].l, w);
        if (tree[tree[p].l].value > tree[p].value) zig(p);
    }
    else
    {
        insert(tree[p].r, w);
        if (tree[tree[p].r].value > tree[p].value) zag(p);
    }
    pushup(p);
}

删除

void remove(int &p, int w)
{
    if (!p) return ;
    if (tree[p].w == w)
    {
        if (tree[p].cnt > 1) tree[p].cnt --;
        else if (tree[p].l || tree[p].r)
        {
            if (!tree[p].r || tree[tree[p].l].value > tree[tree[p].r].value)
            {
                zig(p);
                remove(tree[p].r, w);
            }
            else
            {
                zag(p);
                remove(tree[p].l, w);
            }
        }
        else p = 0;
    }
    else if (tree[p].w > w) remove(tree[p].l, w);
    else remove(tree[p].r, w);
    pushup(p);
}

本题要求的四种查询

int get_rank_by_w(int p, int w) // 根据权值找排名
{
    if (!p) return -1; // 本题已经保证无解,不会出现本情况
    if (tree[p].w == w) return tree[tree[p].l].sizes + 1;
    if (tree[p].w > w) return get_rank_by_w(tree[p].l, w);
    return tree[tree[p].l].sizes + tree[p].cnt + get_rank_by_w(tree[p].r, w);
}

int get_w_by_rank(int p, int rank) // 根据排名找权值
{
    if (!p) return inf;
    if (tree[tree[p].l].sizes >= rank) return get_w_by_rank(tree[p].l, rank);
    if (tree[tree[p].l].sizes + tree[p].cnt >= rank) return tree[p].w;
    return get_w_by_rank(tree[p].r, rank - tree[tree[p].l].sizes - tree[p].cnt);
}

int get_prev(int p, int w) // 根据权值找前驱
{
    if (!p) return -inf;
    if (tree[p].w >= w) return get_prev(tree[p].l, w);
    return max(tree[p].w, get_prev(tree[p].r, w));
}

int get_next(int p, int w) // 根据权值找后驱
{
    if (!p) return inf;
    if (tree[p].w <= w) return get_next(tree[p].r, w);
    return min(tree[p].w, get_next(tree[p].l, w));
}

完整代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <ctime>
#include <cstdlib>
using namespace std;
const int N = 1e5 + 10, inf = 1e8;
int n, root, idx;

struct node
{
    int l, r, w, value;
    int cnt, sizes;
}tree[N];

int get_node(int w)
{
    tree[++ idx].w = w;
    tree[idx].value = rand();
    tree[idx].cnt = tree[idx].sizes = 1;
    return idx;
}

void pushup(int p)
{
    tree[p].sizes = tree[tree[p].l].sizes + tree[tree[p].r].sizes + tree[p].cnt;
}

void zig(int &p) // 右旋
{
    int q = tree[p].l;
    tree[p].l = tree[q].r;
    tree[q].r = p;
    p = q;
    pushup(tree[p].r), pushup(p);
}

void zag(int &p)
{
    int q = tree[p].r;
    tree[p].r = tree[q].l;
    tree[q].l = p;
    p = q;
    pushup(tree[p].l), pushup(p);
}

void build()
{
    get_node(-inf), get_node(inf);
    root = 1, tree[1].r = 2;
    pushup(root);
    if (tree[1].value < tree[2].value) zag(root);
}

void insert(int &p, int w)
{
    if (!p) p = get_node(w);
    else if (tree[p].w == w) tree[p].cnt ++;
    else if (tree[p].w > w)
    {
        insert(tree[p].l, w);
        if (tree[tree[p].l].value > tree[p].value) zig(p);
    }
    else
    {
        insert(tree[p].r, w);
        if (tree[tree[p].r].value > tree[p].value) zag(p);
    }
    pushup(p);
}

void remove(int &p, int w)
{
    if (!p) return ;
    if (tree[p].w == w)
    {
        if (tree[p].cnt > 1) tree[p].cnt --;
        else if (tree[p].l || tree[p].r)
        {
            if (!tree[p].r || tree[tree[p].l].value > tree[tree[p].r].value)
            {
                zig(p);
                remove(tree[p].r, w);
            }
            else
            {
                zag(p);
                remove(tree[p].l, w);
            }
        }
        else p = 0;
    }
    else if (tree[p].w > w) remove(tree[p].l, w);
    else remove(tree[p].r, w);
    pushup(p);
}

int get_rank_by_w(int p, int w)
{
    if (!p) return -1; // 本题已经保证无解,不会出现本情况
    if (tree[p].w == w) return tree[tree[p].l].sizes + 1;
    if (tree[p].w > w) return get_rank_by_w(tree[p].l, w);
    return tree[tree[p].l].sizes + tree[p].cnt + get_rank_by_w(tree[p].r, w);
}

int get_w_by_rank(int p, int rank)
{
    if (!p) return inf;
    if (tree[tree[p].l].sizes >= rank) return get_w_by_rank(tree[p].l, rank);
    if (tree[tree[p].l].sizes + tree[p].cnt >= rank) return tree[p].w;
    return get_w_by_rank(tree[p].r, rank - tree[tree[p].l].sizes - tree[p].cnt);
}

int get_prev(int p, int w)
{
    if (!p) return -inf;
    if (tree[p].w >= w) return get_prev(tree[p].l, w);
    return max(tree[p].w, get_prev(tree[p].r, w));
}

int get_next(int p, int w)
{
    if (!p) return inf;
    if (tree[p].w <= w) return get_next(tree[p].r, w);
    return min(tree[p].w, get_next(tree[p].l, w));
}

int main()
{
#ifdef FIO
    freopen("E:/Code/In.in", "r", stdin);
    freopen("E:/Code/Out.out", "w", stdout);
#endif
    srand(unsigned(time(0)));
    build();

    scanf("%d", &n);
    while (n --)
    {
        int op, x;
        scanf("%d%d", &op, &x);
        if (op == 1) insert(root, x);
        else if (op == 2) remove(root, x);
        else if (op == 3) printf("%d\n", get_rank_by_w(root, x) - 1);
        else if (op == 4) printf("%d\n", get_w_by_rank(root, x + 1));
        else if (op == 5) printf("%d\n", get_prev(root, x));
        else printf("%d\n", get_next(root, x));
    }

    return 0;
}
posted @ 2022-01-08 21:37  幼稚园茶茶子  阅读(20)  评论(0编辑  收藏  举报