Treap 学习笔记
Treap 学习笔记
突然想起来鸽了很久没学的平衡树,今天去网上查找各种资料学了 Treap。
大概是全机房最后一个学会 Treap 的罢/kk。
以 P3369【模板】普通平衡树 为例,写一写自己对 Treap 的理解。
- Treap = Tree + Heap,即在 BST 的基础上多出了堆的性质。
- Treap 的平衡来源于随机给节点赋优先级,插入节点时通过合理的旋转使优先级满足堆的性质。
这道题要我们支持以下六个操作:
- 插入 \(x\)
- 删除 \(x\)
- 查询 \(x\) 的排名。
- 查询排名为 \(x\) 的数。
- 查询 \(x\) 的前驱。
- 查询 \(x\) 的后继。
首先存一下需要维护的东西:
int root; //目前的根
int tot; //节点的个数
struct treap{
int son[2];//son[0]左儿子,son[1]右儿子
int val; //该节点所存储的值
int cnt; //该节点所存储的值的个数
int siz; //以该节点为根的子树中有多少个数
int dat; //随机赋的优先级
}node[N];
维护节点的 siz
:
void pushup(int rt){
//该节点siz=左右子树siz之和+该节点cnt
node[rt].siz=node[node[rt].son[0]].siz+node[node[rt].son[1]].siz+node[rt].cnt;
}
新建节点:
int build(int val){//新建值为val的节点
tot++;
node[tot].val=val;
node[tot].cnt=node[tot].siz=1;//新建时一定已经走到了叶节点,所以cnt和siz都赋值为1
node[tot].son[0]=node[tot].son[1]=0;
node[tot].dat=rand();//这里据说用库函数有可能被卡,可以自己手写rand()
}
关键操作,旋转:
void rotate(int &rt,int d){//rt:所要旋转的根节点,d=0/1表示左旋/右旋
int k=node[rt].son[d^1];//取出与旋转方向相反的儿子节点
node[rt].son[d^1]=node[k].son[d];//把这个位置的节点改为其前驱/后继
node[k].son[d]=rt;//把旋转上去的节点连向根
rt=k;//根变成了旋转上去的节点
pushup(node[rt].son[d]);//维护旋转下去的节点
pushup(rt);//维护新的根
}
插入节点:
void insert(int &rt,int val){//rt:当前遍历到的根节点,val:要插入的数
if(!rt){
rt=build(val);//如果这个节点不存在,新建一个节点维护
return;
}
node[rt].siz++;//插入一个数一定会使当前遍历到的节点siz增加1
if(node[rt].val==val) node[rt].cnt++;//如果当前节点val==要插入的val,cnt增加1
else if(node[rt].val>val){
insert(node[rt].son[0],val);//当前节点val>要插入的val,在左儿子递归找
if(node[rt].dat<node[node[rt].son[0]].dat) rotate(rt,1);//维护优先级
}else{
insert(node[rt].son[1],val);//当前节点val<要插入的val,在右儿子递归找
if(node[rt].dat<node[node[rt].son[1]].dat) rotate(rt,0);//维护优先级
}
pushup(rt);//维护siz
}
另外,在所有操作开始前,应插入一个极大值和极小值以防越界:
void init(){
root=build(-INF);
insert(root,INF);
}
删除节点:
void del(int &rt,int val){//rt:当前遍历到的根节点,val:要删除的数
if(!rt) return;//遍历到空节点,证明没有要删除的数
if(node[rt].val==val){//找到了这个数
//如果这个数的个数>1(>=2),删除一个后这个数仍然存在,直接cnt--并维护siz
if(node[rt].cnt>1){
node[rt].cnt--;
pushup(rt);
return;
}
if(node[rt].son[0]||node[rt].son[1]){
//如果只有左儿子或左儿子优先级高于右儿子
if(!node[rt].son[1]||node[node[rt].son[0]].dat>node[node[rt].son[1]].dat){
//把这个节点右旋(让左儿子变为现在的根)并删除
rotate(rt,1);
del(node[rt].son[1],val);
}else{
//否则,把这个节点左旋(让右儿子变为现在的根)并删除
rotate(rt,0);
del(node[rt].son[0],val);
}
}else rt=0;//如果这个节点不存在儿子,直接变为空节点
}else if(node[rt].val>val) del(node[rt].son[0],val);//当前节点val>要删除的val,在左儿子递归找
else del(node[rt].son[1],val);//当前节点val<要删除的val,在右儿子递归找
pushup(rt);//维护新根的siz
}
查询值对应的排名:
//这里查询到的排名是包含-INF的排名,实际排名应-1
int getrk(int val){//查询val的排名
int rt=root,res=0;
while(rt){
//如果当前节点val=查询的val,返回左子树siz+已跳过的res+1
if(node[rt].val==val) return node[node[rt].son[0]].siz+res+1;
//如果当前节点val<查询的val,跳过左子树siz个节点,在右子树查询
if(node[rt].val<val){
res+=node[node[rt].son[0]].siz+node[rt].cnt;
rt=node[rt].son[1];
}else rt=node[rt].son[0];//如果当前节点val>查询的val,在左子树查询
}
return res;
}
查询排名对应的值:
//这里的参数rk应该包括-INF带来的影响,传参时应传rk+1
int getval(int rk){//查询排名为rk的数
int rt=root;
//这里的写法非常优秀,刚好以rk的相对大小分类
while(rt){
//如果查询的rk<=左子树siz,在左子树查询rk
if(node[node[rt].son[0]].siz>=rk) rt=node[rt].son[0];
//否则,如果查询的rk在(左子树siz,左子树siz+当前节点cnt]内,结果即为当前节点val
else if(node[node[rt].son[0]].siz+node[rt].cnt>=rk) return node[rt].val;
//否则,在右子树中查询rk-(左子树siz+当前节点cnt)
else{
rk-=(node[node[rt].son[0]].siz+node[rt].cnt);
rt=node[rt].son[1];
}
}
}
查询前驱:
int getpre(int val){//查询val的前驱
int rt=root,res;
while(rt){
//如果当前节点val<查询的val,返回值更新为当前节点val,在右子树查询
if(node[rt].val<val){
res=node[rt].val;
rt=node[rt].son[1];
//否则,在左子树查询
}else rt=node[rt].son[0];
}
return res;
}
查询后继:
int getnxt(int val){//查询val的后继
int rt=root,res;
while(rt){
//如果当前节点val>查询的val,返回值更新为当前节点val,在左子树查询
if(node[rt].val>val){
res=node[rt].val;
rt=node[rt].son[0];
//否则,在右子树查询
}else rt=node[rt].son[1];
}
return res;
}
\(\rm Code:\)
#include <bits/stdc++.h>
using namespace std;
const int N=100010;
const int INF=0x3f3f3f3f;
inline int read(){
int x=0;bool f=false;char ch=getchar();
while(!isdigit(ch)){
if(ch=='-') f=true;
ch=getchar();
}
while(isdigit(ch)){
x=(x<<3)+(x<<1)+(ch^48);
ch=getchar();
}
return f?-x:x;
}
struct treap{
int son[2];
int val,cnt,siz,dat;
}node[N];
int t,root,tot;
int randdd(){
static unsigned int random=114514;
return (random*1919810ull)%2147483647;
}
void pushup(int rt){
node[rt].siz=node[node[rt].son[0]].siz+node[node[rt].son[1]].siz+node[rt].cnt;
}
int build(int val){
tot++;
node[tot].val=val;
node[tot].cnt=node[tot].siz=1;
node[tot].son[0]=node[tot].son[1]=0;
node[tot].dat=rand();
return tot;
}
void rotate(int &rt,int d){
int k=node[rt].son[d^1];
node[rt].son[d^1]=node[k].son[d];
node[k].son[d]=rt;
rt=k;
pushup(node[rt].son[d]);
pushup(rt);
}
void insert(int &rt,int val){
if(!rt){
rt=build(val);
return;
}
node[rt].siz++;
if(node[rt].val==val) node[rt].cnt++;
else if(node[rt].val>val){
insert(node[rt].son[0],val);
if(node[rt].dat<node[node[rt].son[0]].dat) rotate(rt,1);
}else{
insert(node[rt].son[1],val);
if(node[rt].dat<node[node[rt].son[1]].dat) rotate(rt,0);
}
pushup(rt);
}
void init(){
root=build(-INF);
insert(root,INF);
}
void del(int &rt,int val){
if(!rt) return;
if(node[rt].val==val){
if(node[rt].cnt>1){
node[rt].cnt--;
pushup(rt);
return;
}
if(node[rt].son[0]||node[rt].son[1]){
if(!node[rt].son[1]||node[node[rt].son[0]].dat>node[node[rt].son[1]].dat){
rotate(rt,1);
del(node[rt].son[1],val);
}else{
rotate(rt,0);
del(node[rt].son[0],val);
}
}else rt=0;
}else if(node[rt].val>val) del(node[rt].son[0],val);
else del(node[rt].son[1],val);
pushup(rt);
}
int getrk(int val){
int rt=root,res=0;
while(rt){
if(node[rt].val==val) return node[node[rt].son[0]].siz+res+1;
if(node[rt].val<val){
res+=node[node[rt].son[0]].siz+node[rt].cnt;
rt=node[rt].son[1];
}else rt=node[rt].son[0];
}
return res;
}
int getval(int rk){
int rt=root;
while(rt){
if(node[node[rt].son[0]].siz>=rk) rt=node[rt].son[0];
else if(node[node[rt].son[0]].siz+node[rt].cnt>=rk) return node[rt].val;
else{
rk-=(node[node[rt].son[0]].siz+node[rt].cnt);
rt=node[rt].son[1];
}
}
}
int getpre(int val){
int rt=root,res;
while(rt){
if(node[rt].val<val){
res=node[rt].val;
rt=node[rt].son[1];
}else rt=node[rt].son[0];
}
return res;
}
int getnxt(int val){
int rt=root,res;
while(rt){
if(node[rt].val>val){
res=node[rt].val;
rt=node[rt].son[0];
}else rt=node[rt].son[1];
}
return res;
}
signed main(){
srand(time(0));
t=read();
init();
while(t--){
int opt,x;
opt=read();x=read();
if(opt==1) insert(root,x);
if(opt==2) del(root,x);
if(opt==3) printf("%d\n",getrk(x)-1);
if(opt==4) printf("%d\n",getval(x+1));
if(opt==5) printf("%d\n",getpre(x));
if(opt==6) printf("%d\n",getnxt(x));
}
return 0;
}
contact me on QQ (601585974 布鲁)