Treap 学习笔记

Treap 学习笔记

突然想起来鸽了很久没学的平衡树,今天去网上查找各种资料学了 Treap。

大概是全机房最后一个学会 Treap 的罢/kk。

P3369【模板】普通平衡树 为例,写一写自己对 Treap 的理解。

  • Treap = Tree + Heap,即在 BST 的基础上多出了堆的性质。
  • Treap 的平衡来源于随机给节点赋优先级,插入节点时通过合理的旋转使优先级满足堆的性质。

这道题要我们支持以下六个操作:

  • 插入 \(x\)
  • 删除 \(x\)
  • 查询 \(x\) 的排名。
  • 查询排名为 \(x\) 的数。
  • 查询 \(x\) 的前驱。
  • 查询 \(x\) 的后继。

首先存一下需要维护的东西:

int root;      //目前的根
int tot;       //节点的个数
struct treap{
    int son[2];//son[0]左儿子,son[1]右儿子
    int val;   //该节点所存储的值
    int cnt;   //该节点所存储的值的个数
    int siz;   //以该节点为根的子树中有多少个数
    int dat;   //随机赋的优先级
}node[N];

维护节点的 siz

void pushup(int rt){
    //该节点siz=左右子树siz之和+该节点cnt
    node[rt].siz=node[node[rt].son[0]].siz+node[node[rt].son[1]].siz+node[rt].cnt;
}

新建节点:

int build(int val){//新建值为val的节点
    tot++;
    node[tot].val=val;
    node[tot].cnt=node[tot].siz=1;//新建时一定已经走到了叶节点,所以cnt和siz都赋值为1
    node[tot].son[0]=node[tot].son[1]=0;
    node[tot].dat=rand();//这里据说用库函数有可能被卡,可以自己手写rand()
}

关键操作,旋转:

void rotate(int &rt,int d){//rt:所要旋转的根节点,d=0/1表示左旋/右旋
    int k=node[rt].son[d^1];//取出与旋转方向相反的儿子节点
	node[rt].son[d^1]=node[k].son[d];//把这个位置的节点改为其前驱/后继
	node[k].son[d]=rt;//把旋转上去的节点连向根
	rt=k;//根变成了旋转上去的节点
	pushup(node[rt].son[d]);//维护旋转下去的节点
	pushup(rt);//维护新的根
}

插入节点:

void insert(int &rt,int val){//rt:当前遍历到的根节点,val:要插入的数
	if(!rt){
		rt=build(val);//如果这个节点不存在,新建一个节点维护
		return;
	}
	node[rt].siz++;//插入一个数一定会使当前遍历到的节点siz增加1
	if(node[rt].val==val) node[rt].cnt++;//如果当前节点val==要插入的val,cnt增加1
	else if(node[rt].val>val){
		insert(node[rt].son[0],val);//当前节点val>要插入的val,在左儿子递归找
		if(node[rt].dat<node[node[rt].son[0]].dat) rotate(rt,1);//维护优先级
	}else{
		insert(node[rt].son[1],val);//当前节点val<要插入的val,在右儿子递归找
		if(node[rt].dat<node[node[rt].son[1]].dat) rotate(rt,0);//维护优先级
	}
	pushup(rt);//维护siz
}

另外,在所有操作开始前,应插入一个极大值和极小值以防越界:

void init(){
	root=build(-INF);
	insert(root,INF);
}

删除节点:

void del(int &rt,int val){//rt:当前遍历到的根节点,val:要删除的数
	if(!rt) return;//遍历到空节点,证明没有要删除的数
	if(node[rt].val==val){//找到了这个数
		//如果这个数的个数>1(>=2),删除一个后这个数仍然存在,直接cnt--并维护siz
		if(node[rt].cnt>1){
			node[rt].cnt--;
			pushup(rt);
			return;
		}
		if(node[rt].son[0]||node[rt].son[1]){
			//如果只有左儿子或左儿子优先级高于右儿子
			if(!node[rt].son[1]||node[node[rt].son[0]].dat>node[node[rt].son[1]].dat){
				//把这个节点右旋(让左儿子变为现在的根)并删除
				rotate(rt,1);
				del(node[rt].son[1],val);
			}else{
				//否则,把这个节点左旋(让右儿子变为现在的根)并删除
				rotate(rt,0);
				del(node[rt].son[0],val);
			}
		}else rt=0;//如果这个节点不存在儿子,直接变为空节点
	}else if(node[rt].val>val) del(node[rt].son[0],val);//当前节点val>要删除的val,在左儿子递归找
	else del(node[rt].son[1],val);//当前节点val<要删除的val,在右儿子递归找
	pushup(rt);//维护新根的siz
}

查询值对应的排名:

//这里查询到的排名是包含-INF的排名,实际排名应-1
int getrk(int val){//查询val的排名
	int rt=root,res=0;
	while(rt){
		//如果当前节点val=查询的val,返回左子树siz+已跳过的res+1
		if(node[rt].val==val) return node[node[rt].son[0]].siz+res+1;
		//如果当前节点val<查询的val,跳过左子树siz个节点,在右子树查询
		if(node[rt].val<val){
			res+=node[node[rt].son[0]].siz+node[rt].cnt;
			rt=node[rt].son[1];
		}else rt=node[rt].son[0];//如果当前节点val>查询的val,在左子树查询
	}
	return res;
}

查询排名对应的值:

//这里的参数rk应该包括-INF带来的影响,传参时应传rk+1
int getval(int rk){//查询排名为rk的数
	int rt=root;
	//这里的写法非常优秀,刚好以rk的相对大小分类
	while(rt){
		//如果查询的rk<=左子树siz,在左子树查询rk
		if(node[node[rt].son[0]].siz>=rk) rt=node[rt].son[0];
		//否则,如果查询的rk在(左子树siz,左子树siz+当前节点cnt]内,结果即为当前节点val
		else if(node[node[rt].son[0]].siz+node[rt].cnt>=rk) return node[rt].val;
		//否则,在右子树中查询rk-(左子树siz+当前节点cnt)
		else{
			rk-=(node[node[rt].son[0]].siz+node[rt].cnt);
			rt=node[rt].son[1];
		}
	}
}

查询前驱:

int getpre(int val){//查询val的前驱
	int rt=root,res;
	while(rt){
		//如果当前节点val<查询的val,返回值更新为当前节点val,在右子树查询
		if(node[rt].val<val){
			res=node[rt].val;
			rt=node[rt].son[1];
		//否则,在左子树查询
		}else rt=node[rt].son[0];
	}
	return res;
}

查询后继:

int getnxt(int val){//查询val的后继
	int rt=root,res;
	while(rt){
		//如果当前节点val>查询的val,返回值更新为当前节点val,在左子树查询
		if(node[rt].val>val){
			res=node[rt].val;
			rt=node[rt].son[0];
		//否则,在右子树查询
		}else rt=node[rt].son[1];
	}
	return res;
}

\(\rm Code:\)

#include <bits/stdc++.h>
using namespace std;
const int N=100010;
const int INF=0x3f3f3f3f;
inline int read(){
	int x=0;bool f=false;char ch=getchar();
	while(!isdigit(ch)){
		if(ch=='-')	f=true;
		ch=getchar();
	}
	while(isdigit(ch)){
		x=(x<<3)+(x<<1)+(ch^48);
		ch=getchar();
	}
	return f?-x:x;
}
struct treap{
	int son[2];
	int val,cnt,siz,dat;
}node[N];
int t,root,tot;
int randdd(){
	static unsigned int random=114514;
	return (random*1919810ull)%2147483647;
}
void pushup(int rt){
	node[rt].siz=node[node[rt].son[0]].siz+node[node[rt].son[1]].siz+node[rt].cnt;
}
int build(int val){
	tot++;
	node[tot].val=val;
	node[tot].cnt=node[tot].siz=1;
	node[tot].son[0]=node[tot].son[1]=0;
	node[tot].dat=rand();
	return tot;
}
void rotate(int &rt,int d){
	int k=node[rt].son[d^1];
	node[rt].son[d^1]=node[k].son[d];
	node[k].son[d]=rt;
	rt=k;
	pushup(node[rt].son[d]);
	pushup(rt);
}
void insert(int &rt,int val){
	if(!rt){
		rt=build(val);
		return;
	}
	node[rt].siz++;
	if(node[rt].val==val) node[rt].cnt++;
	else if(node[rt].val>val){
		insert(node[rt].son[0],val);
		if(node[rt].dat<node[node[rt].son[0]].dat) rotate(rt,1);
	}else{
		insert(node[rt].son[1],val);
		if(node[rt].dat<node[node[rt].son[1]].dat) rotate(rt,0);
	}
	pushup(rt);
}
void init(){
	root=build(-INF);
	insert(root,INF);
}
void del(int &rt,int val){
	if(!rt) return;
	if(node[rt].val==val){
		if(node[rt].cnt>1){
			node[rt].cnt--;
			pushup(rt);
			return;
		}
		if(node[rt].son[0]||node[rt].son[1]){
			if(!node[rt].son[1]||node[node[rt].son[0]].dat>node[node[rt].son[1]].dat){
				rotate(rt,1);
				del(node[rt].son[1],val);
			}else{
				rotate(rt,0);
				del(node[rt].son[0],val);
			}
		}else rt=0;
	}else if(node[rt].val>val) del(node[rt].son[0],val);
	else del(node[rt].son[1],val);
	pushup(rt);
}
int getrk(int val){
	int rt=root,res=0;
	while(rt){
		if(node[rt].val==val) return node[node[rt].son[0]].siz+res+1;
		if(node[rt].val<val){
			res+=node[node[rt].son[0]].siz+node[rt].cnt;
			rt=node[rt].son[1];
		}else rt=node[rt].son[0];
	}
	return res;
}
int getval(int rk){
	int rt=root;
	while(rt){
		if(node[node[rt].son[0]].siz>=rk) rt=node[rt].son[0];
		else if(node[node[rt].son[0]].siz+node[rt].cnt>=rk) return node[rt].val;
		else{
			rk-=(node[node[rt].son[0]].siz+node[rt].cnt);
			rt=node[rt].son[1];
		}
	}
}
int getpre(int val){
	int rt=root,res;
	while(rt){
		if(node[rt].val<val){
			res=node[rt].val;
			rt=node[rt].son[1];
		}else rt=node[rt].son[0];
	}
	return res;
}
int getnxt(int val){
	int rt=root,res;
	while(rt){
		if(node[rt].val>val){
			res=node[rt].val;
			rt=node[rt].son[0];
		}else rt=node[rt].son[1];
	}
	return res;
}
signed main(){
	srand(time(0));
	t=read();
	init();
	while(t--){
		int opt,x;
		opt=read();x=read();
		if(opt==1) insert(root,x);
		if(opt==2) del(root,x);
		if(opt==3) printf("%d\n",getrk(x)-1);
		if(opt==4) printf("%d\n",getval(x+1));
		if(opt==5) printf("%d\n",getpre(x));
		if(opt==6) printf("%d\n",getnxt(x));
	}
	return 0;
}
posted @ 2021-07-19 21:32  Blueqwq  阅读(68)  评论(0编辑  收藏  举报