【洛谷P3369】【模板】普通平衡树
题目
题目链接:https://www.luogu.com.cn/problem/P3369
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入 \(x\) 数
- 删除 \(x\) 数(若有多个相同的数,因只删除一个)
- 查询 \(x\) 数的排名(排名定义为比当前数小的数的个数 \(+1\) )
- 查询排名为 \(x\) 的数
- 求 \(x\) 的前驱(前驱定义为小于 \(x\),且最大的数)
- 求 \(x\) 的后继(后继定义为大于 \(x\),且最小的数)
思路
敲了一发 Splay。中途因为不明原因死循环了一个点。
没什么好说的,Splay 入门可以参考这篇 洛谷日报。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=100010,Inf=1e9;
int Q,rt;
struct Splay
{
int tot,fa[N],son[N][2],cnt[N],size[N],val[N];
void pushup(int x)
{
size[x]=size[son[x][0]]+size[son[x][1]]+cnt[x];
}
void build()
{
tot=2; rt=1;
val[1]=-Inf; val[2]=Inf;
fa[2]=1; son[1][1]=2;
cnt[1]=cnt[2]=1;
pushup(2); pushup(1);
}
int pos(int x)
{
return x==son[fa[x]][1];
}
void rotate(int x)
{
int y=fa[x],z=fa[y],k=pos(x),c=son[x][k^1];
fa[x]=z; son[z][pos(y)]=x;
fa[y]=x; son[x][k^1]=y;
fa[c]=y; son[y][k]=c;
pushup(y); pushup(x);
}
void splay(int x,int f=0)
{
while (fa[x]!=f)
{
int y=fa[x],z=fa[y];
if (z!=f)
{
if (pos(x)==pos(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if (!f) rt=x;
}
int find(int x)
{
int p=rt;
while (son[p][x>val[p]] && val[p]!=x)
p=son[p][x>val[p]];
splay(p);
return p;
}
int pre(int x)
{
int p=find(x);
if (val[p]<x) return p;
p=son[p][0];
while (son[p][1]) p=son[p][1];
splay(p);
return p;
}
int nxt(int x)
{
int p=find(x);
if (val[p]>x) return p;
p=son[p][1];
while (son[p][0]) p=son[p][0];
splay(p);
return p;
}
void ins(int x)
{
int p=rt,f=0;
while (p && val[p]!=x)
{
f=p;
p=son[p][x>val[p]];
}
if (val[p]==x) cnt[p]++;
else
{
p=++tot;
val[p]=x; cnt[p]=size[p]=1;
fa[p]=f; son[f][x>val[f]]=p;
}
splay(p);
}
void del(int x)
{
int y=pre(x),z=nxt(x);
splay(y); splay(z,y);
int p=son[z][0];
if (cnt[p]>1)
{
cnt[p]--;
splay(p);
}
else son[z][0]=0;
pushup(z); pushup(p);
}
int get_rank(int x)
{
int p=find(x);
return size[son[p][0]];
}
int get_val(int k)
{
int p=rt;
while (1)
{
if (size[son[p][0]]>=k) p=son[p][0];
else if (size[son[p][0]]+cnt[p]>=k) break;
else k-=size[son[p][0]]+cnt[p],p=son[p][1];
}
splay(p);
return p;
}
}splay;
int main()
{
scanf("%d",&Q);
splay.build();
while (Q--)
{
int opt,x;
scanf("%d%d",&opt,&x);
if (opt==1) splay.ins(x);
if (opt==2) splay.del(x);
if (opt==3) printf("%d\n",splay.get_rank(x));
if (opt==4) printf("%d\n",splay.val[splay.get_val(x+1)]);
if (opt==5) printf("%d\n",splay.val[splay.pre(x)]);
if (opt==6) printf("%d\n",splay.val[splay.nxt(x)]);
}
return 0;
}