// Red-Black Tree
//
// Reference: https://www.cnblogs.com/skywang12345/p/3245399.html

#include<bits/stdc++.h>
using namespace std;
// Node colour for the red-black tree. Unscoped, so the tree code names the
// enumerators directly as `red` and `black`.
enum Col{
	red=0,
	black=1
};
struct node{
	node *fa,*lson,*rson;
	int val;
	Col col;
	int sz;
	node(int _val=0,node *tmp=nullptr):fa(tmp),lson(tmp),rson(tmp),val(_val),col(black),sz(0){}
}tmp;
// Order-statistic red-black tree over int keys; duplicate keys are allowed.
//
// Every node stores its subtree size `sz`, which Insert, Delete and the
// rotations keep current, so rank and k-th-value queries need a single
// root-to-leaf walk. Missing children and the root's parent all point at the
// sentinel `nodeNull`, which is the global `tmp`. The sentinel is black with
// sz 0, so colour and size reads never need a null check. It is shared by
// every tree instance, and Find returns it to signal "not found".
class Red_Black_Tree{
private:
	node *root,*nodeNull;
	// Free every real node of the subtree rooted at x (post-order).
	void Clear(node *x){
		if(x==nodeNull)return;
		Clear(x->lson),Clear(x->rson);
		delete x;
	}
	// Reset the sentinel: self-linked, value 0, black, size 0.
	void Clear_nodeNull(){
		nodeNull->fa=nodeNull->lson=nodeNull->rson=nodeNull;
		nodeNull->val=0,nodeNull->col=black,nodeNull->sz=0;
	}
public:
	Red_Black_Tree(){
		nodeNull=&tmp;
		Clear_nodeNull();
		root=nodeNull;
	}
	~Red_Black_Tree(){Clear(root);}
	// The nodes are owned through raw pointers and freed by the destructor.
	// A member-wise copy would free them twice, so copying is disabled.
	Red_Black_Tree(const Red_Black_Tree&)=delete;
	Red_Black_Tree& operator=(const Red_Black_Tree&)=delete;
private:
	// Recompute x->sz from its children. This is a no-op on the sentinel, so
	// the sentinel's size stays 0.
	void pushup(node *x){
		if(x==nodeNull)return;
		x->sz=x->lson->sz+x->rson->sz+1;
	}
	// Right rotation at x: x's left child y takes x's place and x becomes y's
	// right child. Requires x->lson != nodeNull. Refreshes sz of x, then y.
	void R_Rotate(node *x){
		node *y=x->lson;
		node *f=x->fa;
		y->fa=f;
		if(f==nodeNull)root=y;
		else f->lson==x?f->lson=y:f->rson=y;
		x->lson=y->rson;
		if(y->rson!=nodeNull)y->rson->fa=x;//never write through the sentinel
		y->rson=x,x->fa=y;
		pushup(x),pushup(y);//x is now y's child, so x must be refreshed first
	}
	// Left rotation at x: mirror image of R_Rotate. Requires x->rson != nodeNull.
	void L_Rotate(node *x){
		node *y=x->rson;
		node *f=x->fa;
		y->fa=f;
		if(f==nodeNull)root=y;
		else f->lson==x?f->lson=y:f->rson=y;
		x->rson=y->lson;
		if(y->lson!=nodeNull)y->lson->fa=x;//never write through the sentinel
		y->lson=x,x->fa=y;
		pushup(x),pushup(y);//x is now y's child, so x must be refreshed first
	}
public:
	// Debug hook: walks the subtree rooted at x. It currently performs no checks.
	void Check_DFS(node *x){
		if(x==nodeNull)return;
		Check_DFS(x->lson),Check_DFS(x->rson);
	}
	// Debug hook: prints a marker line, then walks the whole tree.
	void Check(){
		cout<<"!!!!!!"<<endl;
		Check_DFS(root);
	}
	// Insert one copy of val. Equal keys descend to the right, so duplicates
	// are stored as separate nodes. sz is incremented along the search path
	// first; the rotations in Insert_Fixup then recompute the sizes they touch.
	void Insert(int val){
		node *x=root,*y=nodeNull;
		while(x!=nodeNull){
			x->sz++;
			y=x;
			x=(val<x->val)?x->lson:x->rson;
		}
		node *z=new node(val,nodeNull);
		z->sz=1;
		if(y==nodeNull){
			root=z;
			z->col=black;
		}
		else {
			z->fa=y,val<y->val?y->lson=z:y->rson=z;
			z->col=red;
		}
		Insert_Fixup(z);//remove any red parent / red child pair to restore the red-black properties
	}
	// Restore the red-black properties after inserting the red node x, then
	// force the root black. While x->fa is red it cannot be the root, so
	// x->fa->fa is always a real node inside the loop.
	void Insert_Fixup(node *x){
		while(x->fa->col==red){
			if(x->fa==x->fa->fa->lson){
				node *y=x->fa->fa->rson;//uncle
				if(y->col==red){//red uncle: recolour only, no rotation
					x->fa->col=black,y->col=black;
					x->fa->fa->col=red;
					x=x->fa->fa;
					//move up to the grandparent and keep looping
				}
				else {
					if(x==x->fa->rson){//LR
						x=x->fa;
						L_Rotate(x);//turns LR into LL; the black count on every root-to-leaf path is unchanged
					}
					//LL
					x->fa->col=black,x->fa->fa->col=red;
					R_Rotate(x->fa->fa);//recolour, then rotate once
					//after LL or LR, x->fa is black, so the loop ends here
				}
			}
			else {//mirror image: lson and rson swapped
				node *y=x->fa->fa->lson;//uncle
				if(y->col==red){//red uncle: recolour only, no rotation
					x->fa->col=black,y->col=black;
					x->fa->fa->col=red;
					x=x->fa->fa;
					//move up to the grandparent and keep looping
				}
				else {
					if(x==x->fa->lson){//RL
						x=x->fa;
						R_Rotate(x);//turns RL into RR; the black count on every root-to-leaf path is unchanged
					}
					//RR
					x->fa->col=black,x->fa->fa->col=red;
					L_Rotate(x->fa->fa);//recolour, then rotate once
					//after RR or RL, x->fa is black, so the loop ends here
				}
			}
		}
		root->col=black;
	}
	// Leftmost node of x's right subtree. This is x's in-order successor only
	// when x->rson != nodeNull; otherwise the sentinel is returned.
	node* successor(node *x){
		x=x->rson;
		while(x->lson!=nodeNull)x=x->lson;
		return x;
	}
	// Decrement sz on x and on every ancestor up to the root.
	// Called just before x is unlinked.
	void update_sz(node *x){
		for(;x!=nodeNull;x=x->fa)x->sz--;
	}
	// Remove node x from the tree. If x has two children, the successor's
	// value is copied into x, and the successor node is unlinked and freed
	// instead. Any other pointer to that successor then dangles.
	// Passing the sentinel (e.g. the result of a failed Find) is a no-op.
	void Delete(node *x){
		if(x==nodeNull)return;//nothing to remove; also never `delete` the static sentinel
		node *y,*z;
		if(x->lson==nodeNull||x->rson==nodeNull)y=x;
		else y=successor(x),x->val=y->val;//delete x's successor instead
		//y is the node actually unlinked; it has at most one child
		update_sz(y);
		z=(y->lson!=nodeNull)?y->lson:y->rson;
		z->fa=y->fa;//may write the sentinel's fa; Delete_Fixup relies on it
		if(y->fa==nodeNull)root=z;
		else (y->fa->lson==y)?y->fa->lson=z:y->fa->rson=z;
		if(y->col==black)Delete_Fixup(z);//removing a black node breaks the black-height property
		delete y;
	}
	// Restore the red-black properties after a black node was removed above x.
	// x is treated as carrying one extra black ("black+black"). The extra
	// black is pushed upward or absorbed by rotations until x is "red+black"
	// or the root; then x is simply coloured black.
	void Delete_Fixup(node *x){
		while(x!=root&&x->col==black){
			if(x->fa->lson==x){
				node *y=x->fa->rson;//sibling
				if(y->col==red){//Case1
					y->col=black,x->fa->col=red;
					L_Rotate(x->fa);
					y=x->fa->rson;
				}
				//after Case1 the sibling is black: continue with Case2, 3 or 4
				if(y->lson->col==black&&y->rson->col==black){//Case2
					y->col=red;
					x=x->fa;
				}
				else {
					if(y->rson->col==black){//Case3
						y->lson->col=black,y->col=red;
						R_Rotate(y);
						y=x->fa->rson;
					}
					//Case3 always turns into Case4
					//Case4
					y->col=x->fa->col,x->fa->col=black,y->rson->col=black;
					L_Rotate(x->fa);
					x=root;
				}
			}
			else {//mirror image: lson and rson swapped
				node *y=x->fa->lson;//sibling
				if(y->col==red){//Case1
					y->col=black,x->fa->col=red;
					R_Rotate(x->fa);
					y=x->fa->lson;
				}
				//after Case1 the sibling is black: continue with Case2, 3 or 4
				if(y->rson->col==black&&y->lson->col==black){//Case2
					y->col=red;
					x=x->fa;
				}
				else {
					if(y->lson->col==black){//Case3
						y->rson->col=black,y->col=red;
						L_Rotate(y);
						y=x->fa->lson;
					}
					//Case3 always turns into Case4
					//Case4
					y->col=x->fa->col,x->fa->col=black,y->lson->col=black;
					R_Rotate(x->fa);
					x=root;
				}
			}
		}
		x->col=black;
	}
	// Return some node holding val, or nodeNull (the global tmp) if none.
	node* Find(int val){
		node *x=root;
		while(x!=nodeNull){
			if(val==x->val)return x;
			else if(val<x->val)x=x->lson;
			else x=x->rson;
		}
		return nodeNull;
	}
	// Return 1 + the number of stored values strictly less than val.
	// val does not need to be present.
	int Find_rank(int val){
		node *x=root;
		int ans=0;
		while(x!=nodeNull){
			if(val>x->val)ans+=x->lson->sz+1,x=x->rson;
			else x=x->lson;
		}
		return ans+1;
	}
	// Return the rnk-th smallest stored value (1-based), or -1 if rnk is out
	// of range. NOTE(review): -1 can also be a stored value, so callers
	// cannot tell the two cases apart.
	int Find_val(int rnk){
		node *x=root;
		while(x!=nodeNull){
			if(rnk==x->lson->sz+1)return x->val;
			else if(rnk<=x->lson->sz)x=x->lson;
			else rnk-=(x->lson->sz+1),x=x->rson;
		}
		return -1;
	}
	// Return the largest stored value strictly less than val, or -1 if none.
	// NOTE(review): -1 is ambiguous when negative keys are stored.
	int Find_pre(int val){
		node *x=root;
		int ans=-1;
		while(x!=nodeNull){
			if(val>x->val)ans=x->val,x=x->rson;
			else x=x->lson;
		}
		return ans;
	}
	// Return the smallest stored value strictly greater than val, or -1 if none.
	// NOTE(review): -1 is ambiguous when negative keys are stored.
	int Find_next(int val){
		node *x=root;
		int ans=-1;
		while(x!=nodeNull){
			if(val<x->val)ans=x->val,x=x->lson;
			else x=x->rson;
		}
		return ans;
	}
	// Remove one occurrence of val; does nothing if val is not present.
	void Delete(int val){
		Delete(Find(val));
	}
}T;
// Reads n, then n operations "op x" from stdin, and applies them to the
// global tree T:
//   1: insert x          2: delete one x
//   3: print the rank of x (1 + count of values < x)
//   4: print the x-th smallest value
//   5: print the predecessor of x     6: print the successor of x
// Unknown op codes are ignored.
int main(){
	int n;
	if(scanf("%d",&n)!=1)return 0;//missing/invalid count: n would otherwise be used uninitialised
	for(int i=1;i<=n;i++){
		int op,x;
		if(scanf("%d%d",&op,&x)!=2)break;//truncated input: stop instead of replaying stale op/x
		switch(op){
			case 1:T.Insert(x);break;
			case 2:T.Delete(x);break;
			case 3:printf("%d\n",T.Find_rank(x));break;
			case 4:printf("%d\n",T.Find_val(x));break;
			case 5:printf("%d\n",T.Find_pre(x));break;
			case 6:printf("%d\n",T.Find_next(x));break;
			default:break;
		}
	}
	return 0;
}
// posted @ 2022-03-22 12:13  zhongzero  阅读(23)  评论(0)  编辑  收藏  举报