【最简单的平衡树】Treap
Treap
树是很有用的一种结构,加上不同的约束规则之后,就形成了各种特性鲜明的结构。
最基本的二叉树,加以约束之后,可以形成 BST、AVL、Heap、Splay、R-B Tree …… 等,适用于各种场景。
对于平衡树,种类有很多,有的严格平衡,每次某个子树上任意两个子树的高度差超过1就会进行调整;也有的弱平衡,两子树高度差不超过一倍就不会调整。平衡树很重要,常用于 map、set 等数据结构和数据库等基础设施。但是平衡树大多并不好写,因此一般使用标准库提供的套件来完成工作。这就导致了有些时候,我们想要定制一些操作的时候,难以在封装好的数据结构上进行操作。
本文介绍一种简单好写的平衡树,并给出模板代码,可用于实际解决查元素排名、查排名元素、查前驱、查后继,以及区间操作等。
Treap= Tree + heap
Treap 使用 BST 来提供二分检索特性,使用 Heap 来管理旋转操作,维护二叉树平衡。相比其他平衡树,多用了一个字段来存储堆的权重。堆权随机生成,由随机性来保证二叉树“大概率是平衡的”。
Treap 有多种实现,可以用指针,也可以用数组;可以使用旋转来保证平衡,也可以用分裂合并来做。OI Wiki 解释说无旋 Treap 具有支持序列化的优势,我们在写题的时候用不到这一点,可以直接上数组,减少指针操作。
参考代码1:数组版
这里使用了 OI Wiki 提供的模板代码,并修正了删除不存在元素时 size-- 的错误,另外还加入了 find / count / delall 操作。
注意一点,查询操作只需要传值,插入删除旋转操作传递的是引用,以便旋转操作修改根节点索引。
Treap 本身并不需要 size 字段,这里加入 size 字段是为了提供 查元素排名 和 排名元素 的操作。
为了清晰展示思路,函数用的都是递归操作。代码如下:
#include <cstdio>
#include <algorithm>
#define maxn 100005
#define INF (1 << 30)
struct treap {
// cnt 是元素重复次数,size 是子树结点数目(计入重复元素),rnd 是堆权
// l, r 自动初始化为0,作为空值;0 号结点是空结点
// cnt 默认0个元素,size 默认子树元素为0
int l[maxn], r[maxn], val[maxn], cnt[maxn], rnd[maxn], size[maxn];
int sz; // array size, used for insert
int rt; // tree root index
int ans;
void lrotate(int& k)
{
int t = r[k];
r[k] = l[t];
l[t] = k;
size[t] = size[k];
size[k] = size[l[k]] + size[r[k]] + cnt[k];
k = t;
}
void rrotate(int& k)
{
int t = l[k];
l[k] = r[t];
r[t] = k;
size[t] = size[k];
size[k] = size[l[k]] + size[r[k]] + cnt[k];
k = t;
}
void insert(int& k, int x)
{
if (!k) { // append to the end
sz++;
k = sz;
val[k] = x;
cnt[k] = 1;
size[k] = 1;
rnd[k] = rand();
return;
}
size[k]++;
if (val[k] == x) {
cnt[k]++;
} else if (val[k] < x) {
insert(r[k], x);
if (rnd[r[k]] < rnd[k])
lrotate(k);
} else {
insert(l[k], x);
if (rnd[l[k]] < rnd[k])
rrotate(k);
}
}
bool del(int& k, int x)
{
if (!k)
return false;
if (val[k] == x) {
if (cnt[k] > 1) {
cnt[k]--;
size[k]--;
return true;
}
if (l[k] == 0 || r[k] == 0) { // 元素已调整到链条结点或叶节点
k = l[k] + r[k];
return true;
} else if (rnd[l[k]] < rnd[r[k]]) {
rrotate(k);
return del(k, x);
} else {
lrotate(k);
return del(k, x);
}
} else if (val[k] < x) {
bool succ = del(r[k], x);
if (succ)
size[k]--;
return succ;
} else {
bool succ = del(l[k], x);
if (succ)
size[k]--;
return succ;
} // 先把结点转到能删除的位置再删除
}
int delall(int& k, int x)
{
if (!k)
return 0;
if (val[k] == x) {
if (l[k] == 0 || r[k] == 0) { // 元素已调整到链条结点或叶节点
int diff = cnt[k];
k = l[k] + r[k];
return diff;
} else if (rnd[l[k]] < rnd[r[k]]) {
rrotate(k);
return delall(k, x);
} else {
lrotate(k);
return delall(k, x);
}
} else if (val[k] < x) {
int diff = delall(r[k], x);
size[k] -= diff;
return diff;
} else {
int diff = delall(l[k], x);
size[k] -= diff;
return diff;
} // 先把结点转到能删除的位置再删除
}
int find(int k, int x) {
if (!k) return 0;
if (val[k] == x) {
return k;
} else if (x < val[k]) {
return find(l[k], x);
} else {
return find(r[k], x);
}
}
int count(int k, int x) {
k = find(k, x);
if (!k) return 0;
return cnt[k];
}
// 查元素排名:x是第几小的数
int queryrank(int k, int x)
{
if (!k)
return 0;
if (val[k] == x)
return size[l[k]] + 1;
else if (x > val[k]) {
return size[l[k]] + cnt[k] + queryrank(r[k], x);
} else
return queryrank(l[k], x);
}
// 查排名元素:第x小数
int querynum(int k, int x)
{
if (!k)
return 0; // 返回空
if (x <= size[l[k]])
return querynum(l[k], x);
else if (x > size[l[k]] + cnt[k])
return querynum(r[k], x - size[l[k]] - cnt[k]);
else
return val[k];
}
// 查前驱:刚好比x小的元素
void querypre(int k, int x)
{
if (!k)
return;
if (val[k] < x)
ans = k, querypre(r[k], x);
else
querypre(l[k], x);
}
// 查后继:刚好比x大的元素
void querysub(int k, int x)
{
if (!k)
return;
if (val[k] > x)
ans = k, querysub(l[k], x);
else
querysub(r[k], x);
}
} T;
int main()
{
srand(123);
int n;
scanf("%d", &n);
int opt, x;
for (int i = 1; i <= n; i++) {
scanf("%d%d", &opt, &x);
switch (opt) {
case 1:
T.insert(T.rt, x);
break;
case 2:
printf("del %d is %d\n", x, T.del(T.rt, x));
break;
case 3:
printf("delall %d count %d\n", x, T.delall(T.rt, x));
break;
case 4:
printf("rank of %d is %d\n", x, T.queryrank(T.rt, x));
break;
case 5:
printf("value of rank %d is %d\n", x, T.querynum(T.rt, x));
break;
case 6:
T.ans = 0;
T.querypre(T.rt, x);
printf("previous value of %d is %d\n", x, T.val[T.ans]);
break;
case 7:
T.ans = 0;
T.querysub(T.rt, x);
printf("successor value of %d is %d\n", x, T.val[T.ans]);
break;
default:
printf("invalid opt %d\n", opt);
}
}
return 0;
}
注意,这份代码也并不完美,存在一个问题:
添加元素的时候在数组末尾添加,删除元素并不会回收所占数组位置,形成一个个“空洞”,造成空间浪费。
每次删除的时候,都主动用最后一个元素来填补空洞是不可接受的,这会让 delete 操作的时间复杂度上升到 O(N);可以接收的解决办法是记录空洞个数,当“”空洞率”达到一定阈值之后启动整理,一次性压缩所有空洞。
当然,写题的时候是不需要考虑这么多的。
参考代码2:指针版
指针版解决了内存泄漏的问题,经过设计也并不复杂。这里首先参考了 程序员小灰的文章,据说投稿人只有13岁,值得鼓励。当然,也许小灰并没有对投稿进行过审核就在各大平台发了出来,代码中问题多多,包括但不限于:
- left_rotate 给的是 right_rotate 的代码;
- lson 和 rson 写成数组形式时,仍然出现了 rson 的字眼;
- query_rank 和 query_value 都没有针对非法值的处理,运行时会崩溃;
- query_rank 和 query_value 的结果竟然对应不上;
- query_value 写错了左右子孩子;
- query_value 非递归版进左子树的条件少了一个等于号;
- ...
可以看到,一些错误编译的时候就会出现,一些错误运行的时候会崩掉,非常明显。可以下结论,小灰根本没有运行过投稿人给出的代码,这样搞着实不大行啊……
指针实现的关键点,同样是引用的使用。具体来讲就是 typedef Node* Tree;
,这让代码可读性明显提升。
以下是修正过的代码:
#include <cstdio>
#include <algorithm>
using namespace std;
#define Inf 0x3f3f3f3f
typedef struct Node {
Node(int v) {
val = v, cnt = size = 1, fac = rand();
lson = rson = nullptr;
}
// 值 个数 子树大小 优先级
int val, cnt, size, fac;
Node *lson, *rson;
// 更新当前子树大小
inline void push_up() {
size = cnt;
if (lson != nullptr) size += lson->size;
if (rson != nullptr) size += rson->size;
}
}* Tree;
inline int size(Tree t) { return t == nullptr ? 0 : t->size; }
inline void right_rotate(Tree &a) {
Tree b = a->lson;
a->lson = b->rson, b->rson = a, a = b;
a->rson->push_up(), a->push_up();
}
inline void left_rotate(Tree &a) {
Tree b = a->rson;
a->rson = b->lson, b->lson = a, a = b;
a->lson->push_up(), a->push_up();
}
// // 也可以将lson和rson写成一个数组son,然后将左旋和右旋写成一个函数:
// // 注意左旋、右旋中的代码传的参数a需要传引用,因为最后a也要更新
// inline void rotate(Tree &a, int f) {
// Tree b = a->son[f^1];
// a->son[f^1] = b->son[f], b->son[f] = a, a = b;
// a->son[f]->push_up(), a->push_up();
// }
inline void insert(Tree &rt, int val) {
if (rt == nullptr) {
rt = new Node(val);
return;
}
if (val == rt->val) {
rt->cnt++; // 已经有这个点了
} else if (val < rt->val) {
insert(rt->lson, val);
if (rt->fac < rt->lson->fac) right_rotate(rt);
} else {
insert(rt->rson, val);
if (rt->fac < rt->rson->fac) left_rotate(rt);
}
rt->push_up();
}
inline void del(Tree &rt, int val) {
if (rt == nullptr) return; // 没找到
if (val == rt->val) {
if (rt->cnt > 1) {
rt->cnt--, rt->push_up();
return;
}
if (rt->lson == nullptr && rt->rson == nullptr) {
delete rt; rt = nullptr;
return;
} // 叶结点
else {
if (rt->rson == nullptr || (rt->lson != nullptr && rt->lson->fac > rt->rson->fac)) {
right_rotate(rt), del(rt->rson, val);
} // 右子树小,右旋删除
else {
left_rotate(rt), del(rt->lson, val);
} // 左子树小,左旋删除
}
}
else if (val < rt->val) del(rt->lson, val);
else del(rt->rson, val);
rt->push_up();
}
// 询问有多少个数小于等于val(也就相当于查询排名)
inline int query_rank(Tree p, int val) {
int rank = 1; // 有效值的排名,从1开始
while (p != nullptr) {
if (val == p->val) return rank + size(p->lson);
else if (val < p->val) p = p->lson;
else rank += size(p->lson) + p->cnt, p = p->rson;
}
// return rank;
return 0; // 排名为0代表无效值
}
// // query_rank 的递归实现有一个缺点,那就是对于无效值,其返回值并不固定,无法据此判断当前值是否存在于这棵树之中
// inline int query_rank(Tree p, int val) {
// if (p == nullptr) return 0; // 排名为0代表无效值
// if (val == p->val) return 1 + size(p->lson); // 有效值的排名,从1开始
// if (val < p->val) return query_rank(p->lson, val);
// return size(p->lson) + p->cnt + query_rank(p->rson, val);
// }
#define INVALID_RANK 0x7f7f7f7f
// // query_value 递归和非递归实现,均能正常工作
// inline int query_value(Tree p, int rank) {
// // if (rank < 0 || rank > size(p)) return INVALID_RANK;
// while (p != nullptr && rank) {
// if (rank <= size(p->lson)) p = p->lson;
// else if (rank <= size(p->lson) + p->cnt) return p->val;
// else rank -= size(p->lson) + p->cnt, p = p->rson;
// }
// return INVALID_RANK;
// }
inline int query_value(Tree p, int rank) {
if (p == nullptr) return INVALID_RANK; // printf("rank %d not exist!\n", rank);
if (rank <= size(p->lson)) return query_value(p->lson, rank);
if (rank <= size(p->lson) + p->cnt) return p->val;
return query_value(p->rson, rank - size(p->lson) - p->cnt);
}
// 返回比val小的最大的有效值
inline int query_prev(Tree p, int val) {
int pre = -Inf;
while (p != nullptr) {
if (p->val < val) pre = p->val, p = p->rson;
else p = p->lson;
}
return pre;
}
// 返回比val大的最小的有效值
inline int query_next(Tree p, int val) {
int suf = Inf;
while (p != nullptr) {
if (p->val > val) suf = p->val, p = p->lson;
else p = p->rson;
}
return suf;
}
// ----------------------------------------------------------------------------
int main()
{
srand(123);
int n;
scanf("%d", &n);
Tree rt = nullptr; // 必须初始化,不能出现悬空指针
int opt, x;
for (int i = 1; i <= n; i++) {
scanf("%d%d", &opt, &x);
switch (opt) {
case 1:
insert(rt, x);
break;
case 2:
del(rt, x);
// printf("del %d is %d\n", x, del(rt, x));
break;
case 3:
// printf("delall %d count %d\n", x, delall(rt, x));
break;
case 4:
printf("rank of %d is %d\n", x, query_rank(rt, x));
break;
case 5:
printf("value of rank %d is %d\n", x, query_value(rt, x));
break;
case 6:
printf("previous value of %d is %d\n", x, query_prev(rt, x));
break;
case 7:
printf("successor value of %d is %d\n", x, query_next(rt, x));
break;
default:
printf("invalid opt %d\n", opt);
}
}
return 0;
}