[lnsyoj118/luoguP3369]普通平衡树
题意
维护一个数据结构,要求支持插入,删除,根据排名查数,根据数查排名,查询前驱,查询后继\(6\)个操作
sol
考虑到后四个查询的操作,会发现使用二叉搜索树(BST)完全可以实现
为了完成这四个操作,需要在每个节点记录\(3\)个值:
- \(key\) 表示当前节点的数
- \(cnt\) 表示当前节点的数的个数(为了防止出现同一数字出现多次)
- \(size\) 表示当前子树的数的个数(为了方便查询排名)
根据排名查数
当处于节点\(u\)时,设当前需要查询的排名为\(rank\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:
- 如果\(u.lson.size \ge rank\),说明此时要查询的数一定位于\(u\)的左子树,因此答案为左子树中排名为\(rank\)的数
- 如果\(u.lson.size + u.cnt \ge rank\),说明此时要查询的数为\(u.key\),因此答案就为\(u.key\)
- 前两条均不满足,则说明此时要查询的数一定位于\(u\)的右子树,又由于需要去除掉左子树和\(u\)的所有数,因此答案为右子树中排名为\(rank - u.lson.size - u.cnt\)的数
代码
int get_key(int u, int rank){
if (!u) return INF;
if (rank <= tr[tr[u].l].size) return get_key(tr[u].l, rank);
if (rank <= tr[tr[u].l].size + tr[u].cnt) return tr[u].key;
return get_key(tr[u].r, rank - tr[tr[u].l].size - tr[u].cnt);
}
根据数查排名
当处于节点\(u\)时,设当前需要查询的数为\(x\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:
- 如果\(x < u.key\),说明此时要查询的数一定位于\(u\)的左子树,因此答案为左子树中数\(x\)的排名
- 如果\(x = u.key\),说明此时要查询的数为\(u.key\),因此答案为\(u.lson.size + 1\)
- 前两条均不满足,则说明此时要查询的数一定位于\(u\)的右子树,又由于需要加上左子树和\(u\)的所有数,因此答案为右子树中数\(x\)的排名\(+u.lson.size + u.cnt\)
代码
int get_rank(int u, int key){
if (!u) return 0;
if (key < tr[u].key) return get_rank(tr[u].l, key);
if (key == tr[u].key) return tr[tr[u].l].size + 1;
return tr[tr[u].l].size + tr[u].cnt + get_rank(tr[u].r, key);
}
需要注意的是,部分时候为了方便,我们会在BST中加入两个哨兵节点\(-\infty\)与\(+\infty\),此时由于\(-\infty\)的存在,根据排名查数时的\(rank\)需要\(+1\),而根据数查排名时的查得的答案需要\(-1\)
查询前驱
当处于节点\(u\)时,设当前需要查询的数为\(x\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:
- 若\(x \le u.key\),说明此时要查询的数一定不位于\(u\)的右子树,因此答案为左子树中数\(x\)的前驱
- 若\(x > u.key\),说明此时要查询的数可能为\(u.key\),也可能位于\(u\)的右子树,因此答案为右子树中数\(x\)的前驱与\(u.key\)中的最大值
代码
int get_prev(int u, int key){
if (!u) return -INF;
if (tr[u].key >= key) return get_prev(tr[u].l, key);
return max(tr[u].key, get_prev(tr[u].r, key));
}
查询后继
当处于节点\(u\)时,设当前需要查询的数为\(x\),如果此时节点\(u\)为空节点,说明不存在该数,否则分情况讨论:
- 若\(x \ge u.key\),说明此时要查询的数一定不位于\(u\)的左子树,因此答案为右子树中数\(x\)的前驱
- 若\(x < u.key\),说明此时要查询的数可能为\(u.key\),也可能位于\(u\)的左子树,因此答案为左子树中数\(x\)的前驱与\(u.key\)中的最小值
代码
int get_next(int u, int key){
if (!u) return INF;
if (tr[u].key <= key) return get_next(tr[u].r, key);
return min(tr[u].key, get_next(tr[u].l, key));
}
这样一来,这四个查询操作及两个修改操作的复杂度为\(O(h)\),\(h\)为BST高度。在随机数据下,\(h\)趋向于\(\log n\),但由于BST容易被卡的优秀性质,只需递增/递减数据就可以将BST卡成一条链,从而使\(h=n\),因此,我们需要一些手段来使BST的\(h\)无论何时都接近于\(\log n\),平衡树应运而生
旋转
BST有一条很好的性质:容易被卡中序遍历是单调递增的,反过来也成立,如果我们可以通过一些操作,使中序遍历不变,那么这棵树仍是本质相同的BST,而这个能够使中序遍历不变的操作即为旋转,旋转是几乎所有平衡树都需要使用的操作(部分除外,如FHQ-Treap)
两图中序遍历都为\(A,Q,B,P,C\)
在执行\(zig\)操作时,需要进行三次改变:\(p.lson \to q.rson(B), q.rson \to p, p \to q\)
同理,在执行\(zag\)操作时,也需要进行三次改变:\(q.rson \to p.lson(B), p.lson \to q, q \to p\)
代码
void zig(int &u){
int q = tr[u].l;
tr[u].l = tr[q].r, tr[q].r = u, u = q;
}
void zag(int &u){
int q = tr[u].r;
tr[u].r = tr[q].l, tr[q].l = u, u = q;
}
需要注意的是,这里的\(u\)指代的是根节点或某个节点的子节点,当执行\(zig\)或\(zag\)时,所对应的节点也要改变,因此需要在函数中传递引用。旋转操作可以视为是BST上三条边所指节点的交换操作
Treap
Treap是OI中较常用的一种平衡树
Treap是Tree和Heap的结合体,它的原理非常简单粗暴:既然BST在随机数据下趋于\(\log n\),那么我们就把所有数据打乱顺序再插入就好了。显然,在\(99.99\%\)的情况之下,这种方法都是有效的。不过因为大多数平衡树解决的问题都是在线问题,因此我们无法简单地将数据打乱。
Treap给出的解决方案是这样的:对于每一个节点,在插入时赋予它一个随机权值\(val\),由于可以通过\(zig\)和\(zag\)操作将BST的任一一对父子节点交换而不改变BST的本质,因此我们可以参考二叉堆,插入到对应位置后再向上调整,直到BST中的\(val\)仍然满足二叉堆的性质
对于插入操作,我们先将一个节点插入BST中,然后从下往上判断它是否需要调序;而对于删除操作,我们在BST中找到该节点后,为了方便操作,我们将该节点先调整到叶子结点上,再进行删除。具体代码见下:
void insert(int &u, int key){
if (!u) u = create(key); // 没有该节点的话,就创建一个新节点
else if (key == tr[u].key) tr[u].cnt ++ ; // 否则直接在节点上添加标记
else if (key < tr[u].key){
insert(tr[u].l, key);
if (tr[tr[u].l].val < tr[u].val) zig(u); // 向上调序
}
else {
insert(tr[u].r, key);
if (tr[tr[u].r].val < tr[u].val) zag(u); // 向上调序
}
}
void erase(int &u, int key){
if (!u) return ; // 没有该节点的话,无需处理
else if (key == tr[u].key){
if (tr[u].cnt > 1) tr[u].cnt -- ; // 如果存在多个标记,直接删除标记
else if (tr[u].l || tr[u].r){
if (!tr[u].r || tr[tr[u].l].val > tr[tr[u].r].val){
zig(u); // 先向下调序
erase(tr[u].r, key);
}
else{
zag(u); // 先向下调序
erase(tr[u].l, key);
}
}
else u = 0; // 调到叶子节点后直接删除
}
else if (key < tr[u].key) erase(tr[u].l, key);
else erase(tr[u].r, key);
}
需要注意的是,本题的\(size\)是会在旋转、插入、删除操作中随时改变的,类比线段树,我们还需要一个方法来根据子结点的数据反推节点的\(size\),即PUSHUP
代码:
void pushup(int u){
tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt;
}
这样的话,我们就通过精巧的操作使BST基本平衡,平均时间复杂度也随之下降为\(O(n \log n)\),不过值得注意的是,其最坏复杂度仍为\(O(n^2)\),只是如果真的卡出来了,概率堪比十连十金
代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdlib>
using namespace std;
const int N = 100005, INF = 0x3f3f3f3f;
struct Node{
int l, r;
int key, val;
int cnt, size;
}tr[N];
int root, idx;
int n;
int create(int key){
tr[ ++ idx].key = key;
tr[idx].val = rand();
tr[idx].cnt = tr[idx].size = 1;
return idx;
}
void pushup(int u){
tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt;
}
void zig(int &u){
int q = tr[u].l;
tr[u].l = tr[q].r, tr[q].r = u, u = q;
pushup(tr[u].r);
}
void zag(int &u){
int q = tr[u].r;
tr[u].r = tr[q].l, tr[q].l = u, u = q;
pushup(tr[u].l);
}
void build(){
create(-INF), create(INF);
root = 1, tr[1].r = 2;
pushup(root);
}
void insert(int &u, int key){
if (!u) u = create(key);
else if (key == tr[u].key) tr[u].cnt ++ ;
else if (key < tr[u].key){
insert(tr[u].l, key);
if (tr[tr[u].l].val < tr[u].val) zig(u);
}
else {
insert(tr[u].r, key);
if (tr[tr[u].r].val < tr[u].val) zag(u);
}
pushup(u);
}
void erase(int &u, int key){
if (!u) return ;
else if (key == tr[u].key){
if (tr[u].cnt > 1) tr[u].cnt -- ;
else if (tr[u].l || tr[u].r){
if (!tr[u].r || tr[tr[u].l].val > tr[tr[u].r].val){
zig(u);
erase(tr[u].r, key);
}
else{
zag(u);
erase(tr[u].l, key);
}
}
else u = 0;
}
else if (key < tr[u].key) erase(tr[u].l, key);
else erase(tr[u].r, key);
pushup(u);
}
int get_rank(int u, int key){
if (!u) return 0;
if (key < tr[u].key) return get_rank(tr[u].l, key);
if (key == tr[u].key) return tr[tr[u].l].size + 1;
return tr[tr[u].l].size + tr[u].cnt + get_rank(tr[u].r, key);
}
int get_key(int u, int rank){
if (!u) return INF;
if (rank <= tr[tr[u].l].size) return get_key(tr[u].l, rank);
if (rank <= tr[tr[u].l].size + tr[u].cnt) return tr[u].key;
return get_key(tr[u].r, rank - tr[tr[u].l].size - tr[u].cnt);
}
int get_prev(int u, int key){
if (!u) return -INF;
if (tr[u].key >= key) return get_prev(tr[u].l, key);
return max(tr[u].key, get_prev(tr[u].r, key));
}
int get_next(int u, int key){
if (!u) return INF;
if (tr[u].key <= key) return get_next(tr[u].r, key);
return min(tr[u].key, get_next(tr[u].l, key));
}
int main(){
scanf("%d", &n);
build();
while (n -- ){
int op, x;
scanf("%d%d", &op, &x);
switch(op){
case 1: insert(root, x); break;
case 2: erase(root, x); break;
case 3: printf("%d\n", get_rank(root, x) - 1); break;
case 4: printf("%d\n", get_key(root, x + 1)); break;
case 5: printf("%d\n", get_prev(root, x)); break;
case 6: printf("%d\n", get_next(root, x)); break;
default: break;
}
}
return 0;
}
蒟蒻犯的若至错误
- \(zig\)和\(zag\)的时候没有PUSHUP导致整颗BST的\(size\)都计算错误