二叉排序树基本操作

题目描述

编写一棵二叉排序树,来支持以下 \(6\) 种操作:

  1. 插入 \(x\)
  2. 删除 \(x\) 数(若有多个相同的数,因只删除一个;如果 \(x\) 不存在则不需要删除)
  3. 查询 \(x\) 数的排名(排名定义为比当前数小的数的个数 \(+1\) ;如果 \(x\) 不存在则输出 \(-1\))
  4. 查询排名为 \(x\) 的数(如果 \(x\) 大于树中元素个数,则输出 \(-1\)
  5. \(x\) 的前驱(前驱定义为小于 \(x\),且最大的数;如果没有输出 \(-1\) )
  6. \(x\) 的后继(后继定义为大于 \(x\),且最小的数;如果没有输出 \(-1\) )

输入格式

第一行为 \(n\)\(1 \le n \le 10000\)),表示操作的个数,下面 \(n\) 行每行有两个数 \(\text{opt}\)\(x\)\(\text{opt}\) 表示操作的序号( \(1 \leq \text{opt} \leq 6\) )

输出格式

对于操作 \(3,4,5,6\) 每行输出一个数,表示对应答案

样例输入

10
1 3
1 7
1 15
1 12
3 7
2 7
3 7
4 1
5 8
6 8

样例输出

2
-1
3
3
12

实现代码如下:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 100010;
int lson[maxn], rson[maxn], val[maxn], sz, cnt, tot[maxn];
void Insert(int num) {
    if (cnt == 0) {  // 如果树为空,则直接插入根节点
        cnt ++;
        val[++sz] = num;
        tot[sz] = 1;
        return;
    }
    // 判断num是否存在
    int x = 1;
    while (true) {
        if (num == val[x])  // 存在,直接返回
            return;
        else if (num < val[x]) {
            if (lson[x]) x = lson[x];
            else break;
        }
        else {
            if (rson[x]) x = rson[x];
            else break;
        }
    }
    // 插入num
    cnt ++;
    val[++sz] = num;
    tot[sz] = 1;
    x = 1;
    while (true) {
        tot[x] ++;
        if (num < val[x]) {
            if (lson[x]) x = lson[x];
            else {
                lson[x] = sz;
                break;
            }
        }
        else {
            if (rson[x]) x = rson[x];
            else {
                rson[x] = sz;
                break;
            }
        }
    }
}
void Delete(int num) {
    if (sz == 0) return;
    if (cnt == 1) {
        if (val[1] != num) return;
        cnt --;
        lson[1] = rson[1] = 0;
        return;
    }
    int x = 1, p = 0, y, q;
    while (true) {
        if (num == val[x]) break;
        else if (num < val[x]) {
            p = x;
            if (!lson[x]) return;
            x = lson[x];
        }
        else {
            p = x;
            if (!rson[x]) return;
            x = rson[x];
        }
    }
    cnt --;
    x = 1; p = 0;
    while (true) {
        tot[x] --;
        if (num == val[x]) break;
        else if (num < val[x]) {
            p = x;
            x = lson[x];
        }
        else {
            p = x;
            x = rson[x];
        }
    }
    if (!lson[x] && !rson[x]) { // 要删除的x是叶子节点
        if (p) {
            if (lson[p] == x) lson[p] = 0;
            else rson[p] = 0;
        }
    }
    else if (lson[x]) {
        y = lson[x], q = x;
        while (rson[y]) {
            tot[y] --;
            q = y;
            y = rson[y];
        }
        if (lson[q] == y) lson[q] = lson[y];
        else rson[q] = lson[y];
        val[x] = val[y];
    }
    else {
        y = rson[x], q = x;
        while (lson[y]) {
            tot[y] --;
            q = y;
            y = lson[y];
        }
        if (lson[q] == y) lson[q] = rson[y];
        else rson[q] = rson[y];
        val[x] = val[y];
    }
}
int getRank(int num) {
    if (cnt == 0) return -1;
    // 判断num是否存在
    int x = 1;
    bool exist = false;
    while (true) {
        if (num == val[x]) {
            exist = true;
            break;
        }
        else if (num < val[x]) {
            if (lson[x]) x = lson[x];
            else break;
        }
        else {
            if (rson[x]) x = rson[x];
            else break;
        }
    }
    if (!exist) return -1;
    // 然后从上到下判断
    x = 1;
    int res = 0;
    while (true) {
        if (val[x] == num) {
            res ++;
            if (lson[x]) res += tot[lson[x]];
            break;
        }
        else if (val[x] < num) {
            res ++;
            if (lson[x]) res += tot[lson[x]];
            if (rson[x]) x = rson[x];
            else break;
        }
        else {
            if (lson[x]) x = lson[x];
            else break;
        }
    }
    return res;
}
int getNumByRank(int rk) {
    if (rk > cnt) return -1;
    int x = 1;
    while (true) {
        int left_num = 1;
        if (lson[x]) left_num += tot[lson[x]];
        if (left_num == rk) return val[x];
        else if (left_num > rk) x = lson[x];
        else {
            rk -= left_num;
            x = rson[x];
        }
    }
}
int getPre(int num) {
    int res = -1;
    if (cnt == 0) return -1;
    int x = 1;
    while (true) {
        if (val[x] < num) {
            res = val[x];
            if (rson[x]) x = rson[x];
            else break;
        }
        else {
            if (lson[x]) x = lson[x];
            else break;
        }
    }
    return res;
}
int getNext(int num) {
    int res = -1;
    if (cnt == 0) return -1;
    int x = 1;
    while (true) {
        if (val[x] > num) {
            res = val[x];
            if (lson[x]) x = lson[x];
            else break;
        }
        else {
            if (rson[x]) x = rson[x];
            else break;
        }
    }
    return res;
}
int n, op, x;
int main() {
    cin >> n;
    while (n --) {
        cin >> op >> x;
        if (op == 1) Insert(x);
        else if (op == 2) Delete(x);
        else if (op == 3) cout << getRank(x) << endl;
        else if (op == 4) cout << getNumByRank(x) << endl;
        else if (op == 5) cout << getPre(x) << endl;
        else if (op == 6) cout << getNext(x) << endl;
    }
    return 0;
}

使用 set 来实现上述功能的代码:

#include <bits/stdc++.h>
using namespace std;
set<int> st;
int n, op, x;
int main() {
    cin >> n;
    while (n --) {
        cin >> op >> x;
        if (op == 1) st.insert(x);
        else if (op == 2) {
            set<int>::iterator it = st.lower_bound(x);
            if (it != st.end() && (*it) == x) st.erase(it);
        }
        else if (op == 3) {
            set<int>::iterator it = st.lower_bound(x);
            if (it == st.end() || (*it) != x) cout << -1 << endl;
            else cout << distance(st.begin(), it) + 1 << endl;
        }
        else if (op == 4) {
            if (x > st.size()) cout << -1 << endl;
            else {
                set<int>::iterator it = st.begin();
                for (int i = 1; i < x; i ++) it ++;
                cout << (*it) << endl;
            }
        }
        else if (op == 5) {
            set<int>::iterator it = st.lower_bound(x);
            if (it == st.begin()) cout << -1 << endl;
            else {
                it --;
                cout << (*it) << endl;
            }
        }
        else {
            set<int>::iterator it = st.upper_bound(x);
            if (it == st.end()) cout << -1 << endl;
            else cout << (*it) << endl;
        }
    }
    return 0;
}

注意 distance() 函数的时间复杂度是 \(O(n)\) 的。

posted @ 2020-03-13 23:40  quanjun  阅读(381)  评论(0编辑  收藏  举报