[luogu] P3369 【模板】普通平衡树(splay)
P3369 【模板】普通平衡树
题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入 \(x\) 数
- 删除 \(x\) 数(若有多个相同的数,因只删除一个)
- 查询 \(x\) 数的排名(排名定义为比当前数小的数的个数 \(+1\) 。若有多个相同的数,因输出最小的排名)
- 查询排名为 \(x\) 的数
- 求 \(x\) 的前驱(前驱定义为小于 \(x\),且最大的数)
- 求 \(x\) 的后继(后继定义为大于 \(x\),且最小的数)
输入输出格式
输入格式:
第一行为 \(n\) ,表示操作的个数,下面 \(n\) 行每行有两个数 \(opt\) 和 \(x\) , \(opt\) 表示操作的序号( \(1 \leq opt \leq 6\) )
输出格式:
对于操作 \(3,4,5,6\) 每行输出一个数,表示对应答案
输入输出样例
输入样例#1: 复制
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
输出样例#1: 复制
106465
84185
492737
说明
时空限制:\(1000ms,128M\)
1.\(n\) 的数据范围: \(n \leq 100000\)
2.每个数的数据范围: \([-{10}^7, {10}^7]\)
题解
PS:本题解就是给yyb巨佬的讲解贴个代码的
推荐一发yyb巨佬的博客Splay入门解析
本来是2月份学的东西。已经忘得差不多了。之前打的是treap,现在不会啦。
那就打打抄模板吧。
对于这六个操作。就第三个说一下吧。
先找到这个值的节点。splay到根。那么它的排名就是它的左子树+1。
特别注意我们要先加虚点\(-inf\)和\(inf\)。
因为刚刚加点又删点时,查前驱和后继会卡死循环。
代码
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
using namespace std;
const int N=1e6+5;
struct node{
int ch[2],size;
int ff,cnt,val;
}t[N];
int n,m,root=0,tot;
int read(){
int x=0,w=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
return x*w;
}
void pushup(int x){
t[x].size=t[t[x].ch[1]].size+t[t[x].ch[0]].size+t[x].cnt;
}
void rotate(int x){
int y=t[x].ff,z=t[y].ff,k=t[y].ch[1]==x;
t[z].ch[t[z].ch[1]==y]=x;t[x].ff=z;
t[y].ch[k]=t[x].ch[k^1];t[t[x].ch[k^1]].ff=y;
t[y].ff=x;t[x].ch[k^1]=y;
pushup(y);pushup(x);
}
void splay(int x,int rt){
while(t[x].ff!=rt){
int y=t[x].ff,z=t[y].ff;
if(z!=rt)
(t[y].ch[0]==x)^(t[z].ch[0]==y)?rotate(x):rotate(y);rotate(x);
}
if(rt==0)root=x;
}
void find(int x){
int u=root;if(!u)return ;
while(t[u].ch[t[u].val<x]&&x!=t[u].val)
u=t[u].ch[t[u].val<x];splay(u,0);
}
void insert(int x){
int u=root,ff=0;
while(u&&x!=t[u].val)
ff=u,u=t[u].ch[t[u].val<x];
if(u)t[u].cnt++;
else{
u=++tot;
if(ff)t[ff].ch[t[ff].val<x]=u;
t[tot].ff=ff;t[tot].val=x;
t[tot].cnt=1;t[tot].size=1;
}splay(u,0);
}
int Next(int x,int f){
find(x);int u=root;
// cout<<x<<' '<<u<<' '<<f<<endl;
if(t[u].val>x&&f)return u;
if(t[u].val<x&&!f)return u;
u=t[u].ch[f];
while(t[u].ch[f^1])u=t[u].ch[f^1];
return u;
}
void delet(int x){
int pre=Next(x,0),nex=Next(x,1);
splay(pre,0);splay(nex,pre);
int u=t[nex].ch[0];
if(t[u].cnt>1)t[u].cnt--,splay(u,0);
else t[nex].ch[0]=0;
}
int kth(int x){
int u=root;
if(t[u].size<x)return 0;
while(233){
int y=t[u].ch[0];
if(x>t[y].size+t[u].cnt){
x-=t[y].size+t[u].cnt;
u=t[u].ch[1];
}
else {
if(t[y].size>=x)u=y;
else return t[u].val;
}
}
}
int main(){
m=read();insert(-1e9);insert(1e9);
while(m--){
int opt=read();
if(opt==1){int x=read();insert(x);}
if(opt==2){int x=read();delet(x);}
if(opt==3){int x=read();find(x);printf("%d\n",t[t[root].ch[0]].size);}
if(opt==4){int x=read();x++;printf("%d\n",kth(x));}
if(opt==5){int x=read();printf("%d\n",t[Next(x,0)].val);}
if(opt==6){int x=read();printf("%d\n",t[Next(x,1)].val);}
}
return 0;
}