二叉排序树基本操作
题目描述
编写一棵二叉排序树,来支持以下 \(6\) 种操作:
- 插入 \(x\) 数
- 删除 \(x\) 数(若有多个相同的数,因只删除一个;如果 \(x\) 不存在则不需要删除)
- 查询 \(x\) 数的排名(排名定义为比当前数小的数的个数 \(+1\) ;如果 \(x\) 不存在则输出 \(-1\))
- 查询排名为 \(x\) 的数(如果 \(x\) 大于树中元素个数,则输出 \(-1\))
- 求 \(x\) 的前驱(前驱定义为小于 \(x\),且最大的数;如果没有输出 \(-1\) )
- 求 \(x\) 的后继(后继定义为大于 \(x\),且最小的数;如果没有输出 \(-1\) )
输入格式
第一行为 \(n\)(\(1 \le n \le 10000\)),表示操作的个数,下面 \(n\) 行每行有两个数 \(\text{opt}\) 和 \(x\),\(\text{opt}\) 表示操作的序号( \(1 \leq \text{opt} \leq 6\) )
输出格式
对于操作 \(3,4,5,6\) 每行输出一个数,表示对应答案
样例输入
10
1 3
1 7
1 15
1 12
3 7
2 7
3 7
4 1
5 8
6 8
样例输出
2
-1
3
3
12
实现代码如下:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 100010;
int lson[maxn], rson[maxn], val[maxn], sz, cnt, tot[maxn];
void Insert(int num) {
if (cnt == 0) { // 如果树为空,则直接插入根节点
cnt ++;
val[++sz] = num;
tot[sz] = 1;
return;
}
// 判断num是否存在
int x = 1;
while (true) {
if (num == val[x]) // 存在,直接返回
return;
else if (num < val[x]) {
if (lson[x]) x = lson[x];
else break;
}
else {
if (rson[x]) x = rson[x];
else break;
}
}
// 插入num
cnt ++;
val[++sz] = num;
tot[sz] = 1;
x = 1;
while (true) {
tot[x] ++;
if (num < val[x]) {
if (lson[x]) x = lson[x];
else {
lson[x] = sz;
break;
}
}
else {
if (rson[x]) x = rson[x];
else {
rson[x] = sz;
break;
}
}
}
}
void Delete(int num) {
if (sz == 0) return;
if (cnt == 1) {
if (val[1] != num) return;
cnt --;
lson[1] = rson[1] = 0;
return;
}
int x = 1, p = 0, y, q;
while (true) {
if (num == val[x]) break;
else if (num < val[x]) {
p = x;
if (!lson[x]) return;
x = lson[x];
}
else {
p = x;
if (!rson[x]) return;
x = rson[x];
}
}
cnt --;
x = 1; p = 0;
while (true) {
tot[x] --;
if (num == val[x]) break;
else if (num < val[x]) {
p = x;
x = lson[x];
}
else {
p = x;
x = rson[x];
}
}
if (!lson[x] && !rson[x]) { // 要删除的x是叶子节点
if (p) {
if (lson[p] == x) lson[p] = 0;
else rson[p] = 0;
}
}
else if (lson[x]) {
y = lson[x], q = x;
while (rson[y]) {
tot[y] --;
q = y;
y = rson[y];
}
if (lson[q] == y) lson[q] = lson[y];
else rson[q] = lson[y];
val[x] = val[y];
}
else {
y = rson[x], q = x;
while (lson[y]) {
tot[y] --;
q = y;
y = lson[y];
}
if (lson[q] == y) lson[q] = rson[y];
else rson[q] = rson[y];
val[x] = val[y];
}
}
int getRank(int num) {
if (cnt == 0) return -1;
// 判断num是否存在
int x = 1;
bool exist = false;
while (true) {
if (num == val[x]) {
exist = true;
break;
}
else if (num < val[x]) {
if (lson[x]) x = lson[x];
else break;
}
else {
if (rson[x]) x = rson[x];
else break;
}
}
if (!exist) return -1;
// 然后从上到下判断
x = 1;
int res = 0;
while (true) {
if (val[x] == num) {
res ++;
if (lson[x]) res += tot[lson[x]];
break;
}
else if (val[x] < num) {
res ++;
if (lson[x]) res += tot[lson[x]];
if (rson[x]) x = rson[x];
else break;
}
else {
if (lson[x]) x = lson[x];
else break;
}
}
return res;
}
int getNumByRank(int rk) {
if (rk > cnt) return -1;
int x = 1;
while (true) {
int left_num = 1;
if (lson[x]) left_num += tot[lson[x]];
if (left_num == rk) return val[x];
else if (left_num > rk) x = lson[x];
else {
rk -= left_num;
x = rson[x];
}
}
}
int getPre(int num) {
int res = -1;
if (cnt == 0) return -1;
int x = 1;
while (true) {
if (val[x] < num) {
res = val[x];
if (rson[x]) x = rson[x];
else break;
}
else {
if (lson[x]) x = lson[x];
else break;
}
}
return res;
}
int getNext(int num) {
int res = -1;
if (cnt == 0) return -1;
int x = 1;
while (true) {
if (val[x] > num) {
res = val[x];
if (lson[x]) x = lson[x];
else break;
}
else {
if (rson[x]) x = rson[x];
else break;
}
}
return res;
}
int n, op, x;
int main() {
cin >> n;
while (n --) {
cin >> op >> x;
if (op == 1) Insert(x);
else if (op == 2) Delete(x);
else if (op == 3) cout << getRank(x) << endl;
else if (op == 4) cout << getNumByRank(x) << endl;
else if (op == 5) cout << getPre(x) << endl;
else if (op == 6) cout << getNext(x) << endl;
}
return 0;
}
使用 set 来实现上述功能的代码:
#include <bits/stdc++.h>
using namespace std;
set<int> st;
int n, op, x;
int main() {
cin >> n;
while (n --) {
cin >> op >> x;
if (op == 1) st.insert(x);
else if (op == 2) {
set<int>::iterator it = st.lower_bound(x);
if (it != st.end() && (*it) == x) st.erase(it);
}
else if (op == 3) {
set<int>::iterator it = st.lower_bound(x);
if (it == st.end() || (*it) != x) cout << -1 << endl;
else cout << distance(st.begin(), it) + 1 << endl;
}
else if (op == 4) {
if (x > st.size()) cout << -1 << endl;
else {
set<int>::iterator it = st.begin();
for (int i = 1; i < x; i ++) it ++;
cout << (*it) << endl;
}
}
else if (op == 5) {
set<int>::iterator it = st.lower_bound(x);
if (it == st.begin()) cout << -1 << endl;
else {
it --;
cout << (*it) << endl;
}
}
else {
set<int>::iterator it = st.upper_bound(x);
if (it == st.end()) cout << -1 << endl;
else cout << (*it) << endl;
}
}
return 0;
}
注意 distance()
函数的时间复杂度是 \(O(n)\) 的。