平衡树学习笔记

平衡树

平衡树是一类二叉查找树,因为普通的二叉查找树可能会因为特殊的数据的构造变成链,导致原本应该是 \(\mathcal O(\log n)\) 的查找速度退化成为 \(\mathcal O(n)\),损失大量效率。为了解决这个问题,就有了平衡树这一数据结构。

平衡树,就是对二叉查找树进行一些变形,使得这个二叉查找树尽量的平衡,使得单次操作的时间复杂度尽量靠近于 \(\mathcal O(\log n)\)。下面以 Luogu P3369 为例,对一些常用的平衡树进行一些讲解。

非旋 Treap

非旋 Treap 的中心思想还是与带旋 Treap 是一致的,即引入一个优先级 \(prio\) 的概念,让 BST 同时满足堆和 BST 的性质。

非旋 Treap 需要实现两个操作:splitmerge。所有的操作都会采用这两个操作作为基础实现。

Split

split 操作简单来说,就是按照一个键值 \(key\) 将 Treap 分裂成为两个 Treap,满足一个 Treap 中的所有元素都小于等于 \(key\),另一个 Treap 中的所有元素都大于 \(key\)

void split(int x,int key,int &l,int &r)
{
	if (!x) {l=0,r=0;return;}
	if (val[x]<=key) {l=x,split(rs(x),key,rs(x),r);}
	else {r=x,split(ls(x),key,l,ls(x));}
	maintain(x);
}

Merge

merge 操作就是将两个 Treap 合并成为一个 Treap。需要注意的是这两个 Treap 需要满足一个 Treap 中的所有值都要小于另一个 Treap 中的所有值,不过因为我们用到 merge 的大多数时候都是将 split 出的两半 Treap 给重新粘回去。

实现方式类似左偏树?

int merge(int x,int y)
{
	if (!x || !y) return x+y;
	if (prio[x]<prio[y]) {rs(x)=merge(rs(x),y);maintain(x);return x;}
	else {ls(y)=merge(x,ls(y));maintain(y);return y;}
}

插入/删除

插入删除比较类似,可以采用 split 函数进行实现,会比一般的 BST 的写法要简单一些,不过可能常数会大点。(你都用非旋 Treap 了还担心常数)

void insert(int key)
{
	split(root,key,a,c);
	split(a,key-1,a,b);
	if (b) {++cnt[b];++siz[b];}
	else b=newnode(key);
	root=merge(merge(a,b),c);
}
void remove(int key)
{
	split(root,key,a,c);
	split(a,key-1,a,b);
	if (cnt[b]>1) {--cnt[b];--siz[b];a=merge(a,b);}
	root=merge(a,c);
}

其他 BST 基本操作

int Get_Rank(int key)
{
	split(root,key-1,a,b);
	int ans=siz[a]+1;
	root=merge(a,b);
	return ans;
}
int Get_Val(int kth)
{
	int x=root;
	while (x)
		if (kth<=siz[ls(x)]) x=ls(x);
		else if (kth<=siz[ls(x)]+cnt[x]) break;
		else kth-=siz[ls(x)]+cnt[x],x=rs(x);
	return val[x];
}
int find_max(int x) {return rs(x)?find_max(rs(x)):x;}
int find_min(int x) {return ls(x)?find_min(ls(x)):x;}
int Get_Pre(int key)
{
	split(root,key-1,a,b);
	int ans=val[find_max(a)];
	root=merge(a,b);
	return ans;
}
int Get_Nxt(int key)
{
	split(root,key,a,b);
	int ans=val[find_min(b)];
	root=merge(a,b);
	return ans;
}

完整代码

#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a)
//#define int long long
using namespace std;
void read(auto &k)
{
	k=0;auto flag=1;char b=getchar();
	while (!isdigit(b)) {flag=(b=='-')?-1:1;b=getchar();}
	while (isdigit(b)) {k=k*10+b-48;b=getchar();}
	k*=flag;
}
void write(auto k) {if (k<0) {putchar('-'),write(-k);return;}if (k>9) write(k/10);putchar(k%10+48);}
void writewith(auto k,char c) {write(k);putchar(c);}
namespace FHQ_Treap{
	const int _SIZE=1e5;
	#define ls(_) ch[_][0]
	#define rs(_) ch[_][1]
	int ch[_SIZE+5][2],val[_SIZE+5],siz[_SIZE+5],prio[_SIZE+5],cnt[_SIZE+5];
	int tot,root,a,b,c;
	void maintain(int x) {siz[x]=siz[ls(x)]+siz[rs(x)]+cnt[x];}
	int newnode(int key) {++tot,val[tot]=key,prio[tot]=rand(),cnt[tot]=siz[tot]=1;return tot;}
	void split(int x,int key,int &l,int &r)
	{
		if (!x) {l=0,r=0;return;}
		if (val[x]<=key) {l=x,split(rs(x),key,rs(x),r);}
		else {r=x,split(ls(x),key,l,ls(x));}
		maintain(x);
	}
	int merge(int x,int y)
	{
		if (!x || !y) return x+y;
		if (prio[x]<prio[y]) {rs(x)=merge(rs(x),y);maintain(x);return x;}
		else {ls(y)=merge(x,ls(y));maintain(y);return y;}
	}
	void insert(int key)
	{
		split(root,key,a,c);
		split(a,key-1,a,b);
		if (b) {++cnt[b];++siz[b];}
		else b=newnode(key);
		root=merge(merge(a,b),c);
	}
	void remove(int key)
	{
		split(root,key,a,c);
		split(a,key-1,a,b);
		if (cnt[b]>1) {--cnt[b];--siz[b];a=merge(a,b);}
		root=merge(a,c);
	}
	int Get_Rank(int key)
	{
		split(root,key-1,a,b);
		int ans=siz[a]+1;
		root=merge(a,b);
		return ans;
	}
	int Get_Val(int kth)
	{
		int x=root;
		while (x)
			if (kth<=siz[ls(x)]) x=ls(x);
			else if (kth<=siz[ls(x)]+cnt[x]) break;
			else kth-=siz[ls(x)]+cnt[x],x=rs(x);
		return val[x];
	}
	int find_max(int x) {return rs(x)?find_max(rs(x)):x;}
	int find_min(int x) {return ls(x)?find_min(ls(x)):x;}
	int Get_Pre(int key)
	{
		split(root,key-1,a,b);
		int ans=val[find_max(a)];
		root=merge(a,b);
		return ans;
	}
	int Get_Nxt(int key)
	{
		split(root,key,a,b);
		int ans=val[find_min(b)];
		root=merge(a,b);
		return ans;
	}
	#undef ls
	#undef rs
} using namespace FHQ_Treap;
int n;
signed main()
{
	read(n);
	for (int i=1;i<=n;i++)
	{
		int opt,x;read(opt),read(x);
		switch(opt) 
		{
			case 1: insert(x);break;
			case 2: remove(x);break;
			case 3: writewith(Get_Rank(x),'\n');break;
			case 4: writewith(Get_Val(x),'\n');break;
			case 5: writewith(Get_Pre(x),'\n');break;
			case 6: writewith(Get_Nxt(x),'\n');
		}
	}
	return 0;
}

Splay

Splay 是一种神奇的平衡树,给一个关键词就是:转转转

虽然时间复杂度是 \(\mathcal O(n\log n)\) 的,但是常数很大,可能有 \(8\) 左右,所以不算一个比较高效的平衡树。

Splay 首先需要记录每个节点的值,值出现次数,节点左右儿子,父亲节点,节点子树大小(有什么用之后会了解)。

定义

namespace Splay{
	#define lc(x) ch[x][0]
	#define rc(x) ch[x][1]
	const int _SIZE=1e5;
	int root,tot;
	int val[_SIZE+5],cnt[_SIZE+5],ch[_SIZE+5][2],sz[_SIZE+5],fa[_SIZE+5];
    #undef lc
	#undef rc
}

这段代码进行了宏定义,可以比较好的简化代码,提升可读性。因为是定义在命名空间中,所以建议在命名空间结束的时候取消定义,养成一个好习惯,方便以后可能的更大的代码项目。

三个基础操作

void maintain(int x) {sz[x]=sz[lc(x)]+sz[rc(x)]+cnt[x];}//更新子树大小
bool get(int x) {return x==rc(fa[x]);}//判断x为哪个子树
void clear(int x) {val[x]=cnt[x]=lc(x)=rc(x)=sz[x]=fa[x]=0;}//清除x节点

这三个函数的作用应该是显而易见的了。这里不做过多解释。

旋转

Splay 的一个基本操作就是旋转 rotaterotate 操作会将节点 x 向其父节点 y 旋转。具体操作(以 xy 的左儿子为例):将 x 设为 y 的父节点,然后将 x 的右儿子给设置为 y 的左儿子。不难发现,这样的旋转操作不会破坏 BST 的平衡性。

void rotate(int x,int &rt=root)
{
	int y=fa[x],z=fa[y],chk=get(x);//chk用来确定x是在哪一个子树
	ch[y][chk]=ch[x][chk^1];
	if (ch[x][chk^1]) fa[ch[x][chk^1]]=y;//x的儿子与y连边
	ch[x][chk^1]=y,fa[y]=x,fa[x]=z;//x与y父子关系反转
	if (z) ch[z][y==rc(z)]=x;//如果y有父节点z需要将z的儿子改到x
	else rt=x;//如果没有就将这个子树的根改为x
	maintain(y),maintain(x);//更新size
}

Splay

这是 Splay 这种平衡树的最关键的操作,是其时间复杂度的保证。具体的做法就是将某一个节点 x 给旋转到整棵 BST 的一个节点下(一般为根节点)。

void splay(int x,int &rt=root)
{
	int y=fa[x];
	for (;x!=rt;rotate(x,rt),y=fa[x])//没到就一直转
		if (y!=rt) rotate(get(x)==get(y)?y:x,rt);//如果y不是根节点就需要判断x,y是否是在一条链上的,如果是就先旋转y再旋转x,否则旋转两次x;如果y是根节点就只需要旋转一次x
	rt=x;//更新新的子树的根节点
}

需要注意的是,在 Splay 的所有的可能更改平衡树结构的操作时,都需要将新更改的节点 splay 到根节点,否则将无法保证时间复杂度的正确性。

插入

相比于前面的两个函数,接下来的操作就是普通的 BST 也支持的东西了,实现也会比较简单一些了,就直接给出代码了(记住 splay)。

void insert(int k)
{
	if (!root)//树是空的
	{
		root=++tot,cnt[tot]++,val[tot]=k;
		return maintain(root);
	}
	int cur=root,f=0;
	while (1)//非递归实现
	{
		if (k==val[cur])//存在节点
		{
			cnt[cur]++;
			maintain(cur),maintain(f);
			return splay(cur);//记得splay
		}
		f=cur,cur=ch[cur][k>val[cur]];
		if (!cur)//新建节点
		{
			cnt[++tot]++,val[tot]=k;
			fa[tot]=f,ch[f][k>val[f]]=tot;
			maintain(tot),maintain(f);
			return splay(tot);//splay
		}
	}
}

根据 Val 查询排名

int rk(int k)
{
	int res=0,cur=root;//res用于存储目前的排名
	while (1)
	{
		if (k<val[cur]) cur=lc(cur);
		else
		{
			res+=sz[lc(cur)];//不在左子树,就将左子树的全部节点个数统计入排名
			if (k==val[cur]) {splay(cur);return res+1;}
			res+=cnt[cur],cur=rc(cur);
		}
	}
}

根据排名查询 Val

int kth(int k)
{
	int cur=root;
	while (1)
	{
		if (lc(cur) && k<=sz[lc(cur)]) cur=lc(cur);
		else
		{
			k-=sz[lc(cur)]+cnt[cur];//直接减,如果减成负数就证明是当前节点
			if (k<=0) {splay(cur);return val[cur];}
			cur=rc(cur);
		}
	}
}

查询前驱

查询 x 前驱的操作可以变成插入 x,然后将 x splay 到根节点,此时左子树中的最大值就是 x 的前驱,最后再将 x 删除即可。

这里给出查找根节点左子树最大值的函数。

int pre()
{
	int cur=lc(root);
	if (!cur) return cur;
	while (rc(cur)) cur=rc(cur);//只要有右儿子就一直向右下走
	splay(cur);//将cur旋转到根节点
	return cur;//返回节点编号
}

查询后继

与查询前驱基本一致,插入 x,在根节点右子树查找最小值,删除 x

int nxt()
{
	int cur=rc(root);
	if (!cur) return cur;
	while (lc(cur)) cur=lc(cur);
	splay(cur);
	return cur;
}

删除操作

假设删除的数为 x,那么先将 x splay 至根节点,然后删除该数。如果 x 节点的 cnt 值被减为了 \(0\),那么就删除根节点,合并根节点的左右子树(此时 x 已经被 splay 到根节点了)。

假设合并的两棵平衡树为 \(A,B\),那么分情况进行讨论(需要注意,此处的 \(A,B\) 是需要满足 \(\max\{A_i\}<\min\{B_j\}\),即 \(A\) 的最大值小于 \(B\) 的最小值):

  1. 如果 \(A,B\) 二者中有一者为空,就将非空者作为新的平衡树。
  2. 如果 \(A,B\) 二者均非空,那么将 \(A\) 树的最大值 splay 到根节点(此时显然根节点是没有右子树的),然后将 \(A\) 树根节点的右子树接到 \(B\) 树的根节点上,并更新节点信息。
void remove(int k)
{
	rk(k);
	if (cnt[root]>1) {cnt[root]--;return maintain(root);}
	if (!lc(root) && !rc(root)) {clear(root);root=0;return;}//删后为空树
	if (!lc(root))//没有左子树
	{
		int cur=root;root=rc(root),fa[root]=0;
		return clear(cur);
	}
	if (!rc(root))//没有右子树
	{
		int cur=root;root=lc(root),fa[root]=0;
		return clear(cur);
	}
	int cur=root,x=pre();//利用之前写的pre函数获取到左子树的最大值
	fa[rc(cur)]=x,rc(x)=rc(cur);//连接A B根节点
	clear(cur),maintain(root);//清楚原来根节点的信息,更新新根节点的信息
}

总代码

#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a)
//#define int long long
using namespace std;
template<typename T> void read(T &k)
{
	k=0;T flag=1;char b=getchar();
	while (!isdigit(b)) {flag=(b=='-')?-1:1;b=getchar();}
	while (isdigit(b)) {k=k*10+b-48;b=getchar();}
	k*=flag;
}
template<typename T> void write(T k) {if (k<0) {putchar('-'),write(-k);return;}if (k>9) write(k/10);putchar(k%10+48);}
template<typename T> void writewith(T k,char c) {write(k);putchar(c);}
namespace Splay{
	#define lc(x) ch[x][0]
	#define rc(x) ch[x][1]
	const int _SIZE=1e5;
	int root,tot;
	int val[_SIZE+5],cnt[_SIZE+5],ch[_SIZE+5][2],sz[_SIZE+5],fa[_SIZE+5];
	void maintain(int x) {sz[x]=sz[lc(x)]+sz[rc(x)]+cnt[x];}
	bool get(int x) {return x==rc(fa[x]);}
	void clear(int x) {val[x]=cnt[x]=lc(x)=rc(x)=sz[x]=fa[x]=0;}
	void rotate(int x,int &rt=root)
	{
		int y=fa[x],z=fa[y],chk=get(x);
		ch[y][chk]=ch[x][chk^1];
		if (ch[x][chk^1]) fa[ch[x][chk^1]]=y;
		ch[x][chk^1]=y,fa[y]=x,fa[x]=z;
		if (z) ch[z][y==rc(z)]=x;
		else rt=x;
		maintain(y),maintain(x);
	}
	void splay(int x,int &rt=root)
	{
		int y=fa[x];
		for (;x!=rt;rotate(x,rt),y=fa[x])
			if (y!=rt) rotate(get(x)==get(y)?y:x,rt);
		rt=x;
	}
	void insert(int k)
	{
		if (!root)
		{
			root=++tot,cnt[tot]++,val[tot]=k;
			return maintain(root);
		}
		int cur=root,f=0;
		while (1)
		{
			if (k==val[cur])
			{
				cnt[cur]++;
				maintain(cur),maintain(f);
				return splay(cur);
			}
			f=cur,cur=ch[cur][k>val[cur]];
			if (!cur)
			{
				cnt[++tot]++,val[tot]=k;
				fa[tot]=f,ch[f][k>val[f]]=tot;
				maintain(tot),maintain(f);
				return splay(tot);
			}
		}
	}
	int rk(int k)
	{
		int res=0,cur=root;
		while (1)
		{
			if (k<val[cur]) cur=lc(cur);
			else
			{
				res+=sz[lc(cur)];
				if (k==val[cur]) {splay(cur);return res+1;}
				res+=cnt[cur],cur=rc(cur);
			}
		}
	}
	int kth(int k)
	{
		int cur=root;
		while (1)
		{
			if (lc(cur) && k<=sz[lc(cur)]) cur=lc(cur);
			else
			{
				k-=sz[lc(cur)]+cnt[cur];
				if (k<=0) {splay(cur);return val[cur];}
				cur=rc(cur);
			}
		}
	}
	int pre()
	{
		int cur=lc(root);
		if (!cur) return cur;
		while (rc(cur)) cur=rc(cur);
		splay(cur);
		return cur;
	}
	int nxt()
	{
		int cur=rc(root);
		if (!cur) return cur;
		while (lc(cur)) cur=lc(cur);
		splay(cur);
		return cur;
	}
	void remove(int k)
	{
		rk(k);
		if (cnt[root]>1) {cnt[root]--;return maintain(root);}
		if (!lc(root) && !rc(root)) {clear(root);root=0;return;}
		if (!lc(root))
		{
			int cur=root;root=rc(root),fa[root]=0;
			return clear(cur);
		}
		if (!rc(root))
		{
			int cur=root;root=lc(root),fa[root]=0;
			return clear(cur);
		}
		int cur=root,x=pre();
		fa[rc(cur)]=x,rc(x)=rc(cur);
		clear(cur),maintain(root);
	}
	#undef lc
	#undef rc
}using namespace Splay;
int n;
signed main()
{
	read(n);
	for (int i=1;i<=n;i++)
	{
		int opt,x;read(opt),read(x);
		if (opt==1) insert(x);
		if (opt==2) remove(x);
		if (opt==3) writewith(rk(x),'\n');
		if (opt==4) writewith(kth(x),'\n');
		if (opt==5) insert(x),writewith(val[pre()],'\n'),remove(x);
		if (opt==6) insert(x),writewith(val[nxt()],'\n'),remove(x);
	}
	return 0;
}

替罪羊树

替罪羊树是平衡树中效率极其优秀的,常数较小(可能是仅次于红黑树的),缺点是不如 Splay 和非旋 Treap 那么通用。

定义

先给出替罪羊树的变量声明。

namespace SGT{
	const int _SIZE=1e5;
	const double alpha=0.7;//一个常数,后文会提到
	int tot,root,lc[_SIZE+5],rc[_SIZE+5],val[_SIZE+5];//按顺序,节点个数,根节点,左右儿子,节点权值
	int cnt[_SIZE+5],sz[_SIZE+5],szv[_SIZE+5],szd[_SIZE+5];//节点权值个数,子树大小(每个节点记1次),子树大小(每个节点记cnt次),非空子树大小(非空节点记1次)
}

信息更新

根据定义,可以很简单的得出更新方式。

void maintain(int x)
{
	sz[x]=sz[lc[x]]+sz[rc[x]]+1;
	szv[x]=szv[lc[x]]+szv[rc[x]]+cnt[x];
	szd[x]=szd[lc[x]]+szd[rc[x]]+(cnt[x]!=0);
}

判断是否需要重构

替罪羊树引入了一个常数 \(\alpha\)(一般为 \(0.6\sim 0.7\),通常取 \(0.7\)),当某个节点 x 的某一子树 y 大小占到了 x\(\alpha\),那么就将这棵子树重构。并且,如果一棵子树的空节点(空节点定义为一个 cnt=0 的有权值 val 的节点)占到了该子树的 \(\alpha\),那么也将这棵子树重构。

bool canRbd(int x)
{
	return cnt[x] && (alpha*sz[x]<=(double)max(sz[lc[x]],sz[rc[x]])) || sz[x]*alpha>=(double)szd[x];
}

拍扁重构

这是替罪羊树最核心的操作,时间复杂度是由这个操作来保证的。当判断了某个子树是否需要重构后,就需要进行重构操作。

将操作分为两步:拍扁、重构。

拍扁

拍扁就是将某个子树按照中序遍历的顺序存储到一个 vector 中(一般用数组模拟,而不用 vector,否则就可能失去替罪羊树小常数的优势),一般就用普通的二叉树的中序遍历的方式就行。

int ldr[_SIZE+5];//模拟的vector
void Rbd_flatten(int &ldc,int x)//ldc是vector的尾下标
{
	if (!x) return;//不存在x这个节点
	Rbd_flatten(ldc,lc[x]);
	if (cnt[x]) ldr[ldc++]=x;//只要该节点不是空节点就加入vector中
	Rbd_flatten(ldc,rc[x]);
}

重构

直接用二分的方式递归 vector 建树即可。

int Rbd_build(int l,int r)//注意有返回值,返回根节点的新编号,此处l,r是前闭后开
{
	int mid=l+r>>1;
	if (l>=r) return 0;
	lc[ldr[mid]]=Rbd_build(l,mid);//重构左子树
	rc[ldr[mid]]=Rbd_build(mid+1,r);//右子树
	maintain(ldr[mid]);//更新节点信息
	return ldr[mid];
}

总代码:

int ldr[_SIZE+5];
void Rbd_flatten(int &ldc,int x)
{
	if (!x) return;
	Rbd_flatten(ldc,lc[x]);
	if (cnt[x]) ldr[ldc++]=x;
	Rbd_flatten(ldc,rc[x]);
}
int Rbd_build(int l,int r)
{
	int mid=l+r>>1;
	if (l>=r) return 0;
	lc[ldr[mid]]=Rbd_build(l,mid);
	rc[ldr[mid]]=Rbd_build(mid+1,r);
	maintain(ldr[mid]);
	return ldr[mid];
}
void Rebuild(int &x)
{
	int ldc=0;
	Rbd_flatten(ldc,x);
	x=Rbd_build(0,ldc);
}

插入

与普通 BST 相同,采用递归实现。

void insert(int &k,int p)
{
	if (!k)//
	{
		k=++tot;
		if (!root) root=1;
		val[k]=p,lc[k]=rc[k]=0;
		sz[k]=szv[k]=szd[k]=cnt[k]=1;
	}
	else
	{
		if (val[k]==p) cnt[k]++;
		else if (val[k]<p) insert(rc[k],p);
		else insert(lc[k],p);
		maintain(k);
		if (canRbd(k)) Rebuild(k);//每次对BST的结构更改的时候都需要判断是否需要重构
	}
}

删除

替罪羊树的删除是采用类似懒删除的方式,只将对应节点的 cnt--,而不判断删除后是否成为了空节点,当空节点的数目很多的时候才会使用拍扁重构来清除这些空节点。

void remove(int &k,int p)
{
	if (!k) return;//删除的节点不存在,忽略
	if (val[k]==p) cnt[k]--;
	else if (val[k]<p) remove(rc[k],p);
	else remove(lc[k],p);
	maintain(k);
	if (canRbd(k)) Rebuild(k);//判断是否应该重构
}

Upper_bound 和 Upper_greater

uprbd 函数用于找到最小的大于某个值的节点的位置,uprgr 函数用于找到最大的小于某个值的节点的位置,实现方式不讲(BST 基本操作)。

int uprbd(int k,int p)
{
	if (!k) return 1;
	if (val[k]==p && cnt[k]) return szv[lc[k]]+cnt[k]+1;
	else if (p<val[k]) return uprbd(lc[k],p);
	else return szv[lc[k]]+cnt[k]+uprbd(rc[k],p);
}
int uprgr(int k,int p)
{
	if (!k) return 0;
	if (val[k]==p && cnt[k]) return szv[lc[k]];
	else if (p<val[k]) return uprgr(lc[k],p);
	else return szv[lc[k]]+cnt[k]+uprgr(rc[k],p);
}

排名与权值相互查询

查询排名可以用上面的 uprgr 函数直接得到。

int getRank(int x) {return uprgr(root,x)+1;}

查询权值也很简单。

int getVal(int k,int p)
{
	if (!k) return 0;
	if (szv[lc[k]]<p && p<=szv[lc[k]]+cnt[k]) return val[k];
	else if (szv[lc[k]]+cnt[k]<p) return getVal(rc[k],p-szv[lc[k]]-cnt[k]);
	else return getVal(lc[k],p);
}

查询前驱后继

直接用 getVal 函数与 uprbduprgr 函数组合即可。

int getPre(int k,int p) {return getVal(k,uprgr(k,p));}
int getNxt(int k,int p) {return getVal(k,uprbd(k,p));}

总代码

因为 Dev-C++ 不好用中文写注释所以就干脆写英文了(逃

#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a)
//#define int long long
using namespace std;
template<typename T> void read(T &k)
{
	k=0;T flag=1;char b=getchar();
	while (!isdigit(b)) {flag=(b=='-')?-1:1;b=getchar();}
	while (isdigit(b)) {k=k*10+b-48;b=getchar();}
	k*=flag;
}
template<typename T> void write(T k) {if (k<0) {putchar('-'),write(-k);return;}if (k>9) write(k/10);putchar(k%10+48);}
template<typename T> void writewith(T k,char c) {write(k);putchar(c);}
namespace SGT{
	const int _SIZE=1e5;
	const double alpha=0.7;
	int tot,root,lc[_SIZE+5],rc[_SIZE+5],val[_SIZE+5];
	int cnt[_SIZE+5],sz[_SIZE+5],szv[_SIZE+5],szd[_SIZE+5];
	void maintain(int x)
	{
		sz[x]=sz[lc[x]]+sz[rc[x]]+1;//size of tree, each node count as 1
		szv[x]=szv[lc[x]]+szv[rc[x]]+cnt[x];//each node count as its cnt
		szd[x]=szd[lc[x]]+szd[rc[x]]+(cnt[x]!=0);//each node count as 1,deleted node not counted
	}
	bool canRbd(int x)//can rebuild
	{
		return cnt[x] && (alpha*sz[x]<=(double)max(sz[lc[x]],sz[rc[x]])) || sz[x]*alpha>=(double)szd[x];
	}
	int ldr[_SIZE+5];
	void Rbd_flatten(int &ldc,int x)//flatten function in rebuild
	{//ldc->tail of ldr(temp vector for flatten)
		if (!x) return;//node x not exist
		Rbd_flatten(ldc,lc[x]);//into left son
		if (cnt[x]) ldr[ldc++]=x;//node x isnt empty
		Rbd_flatten(ldc,rc[x]);//right son
	}
	int Rbd_build(int l,int r)//returns the new number of node
	{
		int mid=l+r>>1;
		if (l>=r) return 0;
		lc[ldr[mid]]=Rbd_build(l,mid);//get left son node, and rebuild
		rc[ldr[mid]]=Rbd_build(mid+1,r);
		maintain(ldr[mid]);
		return ldr[mid];
	}
	void Rebuild(int &x)
	{
		int ldc=0;
		Rbd_flatten(ldc,x);//flatten into vector
		x=Rbd_build(0,ldc);//rebuild vector into BST
	}
	void insert(int &k,int p)
	{
		if (!k)//p not in BST
		{
			k=++tot;//newnode
			if (!root) root=1;//empty tree
			val[k]=p,lc[k]=rc[k]=0;//new node
			sz[k]=szv[k]=szd[k]=cnt[k]=1;//init
		}
		else
		{
			if (val[k]==p) cnt[k]++;//current node is p,add 1 to cnt
			else if (val[k]<p) insert(rc[k],p);//p in right son
			else insert(lc[k],p);// p in left son
			maintain(k);//update
			if (canRbd(k)) Rebuild(k);//rebuild
		}
	}
	void remove(int &k,int p)
	{
		if (!k) return;//k not in BST
		if (val[k]==p) cnt[k]--;//current node is p,del it
		else if (val[k]<p) remove(rc[k],p);//p in right son
		else remove(lc[k],p);//p in left son
		maintain(k);//update
		if (canRbd(k)) Rebuild(k);//rebuild
	}
	int uprbd(int k,int p)//same as upper_bound, finds the smallest elements greater than p
	{
		if (!k) return 1;//p is smaller than any one
		if (val[k]==p && cnt[k]) return szv[lc[k]]+cnt[k]+1;//p equals to current node, the ans is first one in right son
		else if (p<val[k]) return uprbd(lc[k],p);//p in left son
		else return szv[lc[k]]+cnt[k]+uprbd(rc[k],p);//p in right son
	}
	int uprgr(int k,int p)//finds the largerst elements smaller than p
	{
		if (!k) return 0;
		if (val[k]==p && cnt[k]) return szv[lc[k]];
		else if (p<val[k]) return uprgr(lc[k],p);
		else return szv[lc[k]]+cnt[k]+uprgr(rc[k],p);
	}
	int getRank(int x) {return uprgr(root,x)+1;}//rank is the rank of upper_greater(p) + 1
	int getVal(int k,int p) // get value by rank
	{
		if (!k) return 0;
		if (szv[lc[k]]<p && p<=szv[lc[k]]+cnt[k]) return val[k];
		else if (szv[lc[k]]+cnt[k]<p) return getVal(rc[k],p-szv[lc[k]]-cnt[k]);
		else return getVal(lc[k],p);
	}
	int getPre(int k,int p) {return getVal(k,uprgr(k,p));}
	int getNxt(int k,int p) {return getVal(k,uprbd(k,p));}
	void print(int x)
	{
		if (lc[x]) print(lc[x]);
		if (cnt[x]) writewith(val[x],' ');
		if (rc[x]) print(rc[x]);
	}
} using namespace SGT;
int n;
signed main()
{
	read(n);
	for (int i=1;i<=n;i++)
	{
		int opt,x;read(opt),read(x);
		if (opt==1) insert(root,x);
		if (opt==2) remove(root,x);
		if (opt==3) writewith(getRank(x),'\n');
		if (opt==4) writewith(getVal(root,x),'\n');
		if (opt==5) writewith(getPre(root,x),'\n');
		if (opt==6) writewith(getNxt(root,x),'\n');
	}
	return 0;
}
posted @ 2022-09-24 14:43  Hanx16Msgr  阅读(16)  评论(0编辑  收藏  举报