Splay学习笔记 & P3369 【模板】普通平衡树
传送门
Splay
细节是真他妈多,写了一天,写吐了。
后悔没先学treap了。
需要实现以下函数:
- void Init():重置整棵树(删除了整棵树的时候用)
- int New(int val,int fa):新建一个节点,权值为val,父亲为fa,返回节点编号
- void Delete(int x):清空节点x的信息
- void update(int x):更新x节点的siz大小(类似线段树的push up)
- void rotate(int x):把节点x旋转到x父亲的位置,左旋右旋可以写在一起
- void splay(int x,int goal):把x通过不断旋转,直到父亲是goal
- bool find(int val):找到值为val的节点,并旋转到根,返回是否成功找到
- int pre():找到第一个小于根节点val的节点的编号
- int nxt():找到第一个大于根节点val的节点的编号
- void insert(int val):插入值为val的点
- void del(int val):删除值为val的点,注意若整棵树就一个节点需要Init()一下
- int getrk(int val):返回值为val的节点的排名
- int getval(int x):返回排名为x的节点权值
Q:要注意啥?
A:呵呵,真的没啥要注意的,就注意别写挂了行了/kx/kx/kx/kx
AC代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<vector>
using namespace std;
template<class T>inline void read(T &x)
{
x=0;register char c=getchar();register bool f=0;
while(!isdigit(c))f^=c=='-',c=getchar();
while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
if(f)x=-x;
}
template<class T>inline void print(T x)
{
if(x<0)putchar('-'),x=-x;
if(x>9)print(x/10);
putchar('0'+x%10);
}
const int maxn=1e5+5;
int n,rt,cnt;
struct node{
int fa,son[2],siz,num,val;
}tr[maxn];
inline void Init(){
tr[rt].fa=tr[rt].son[0]=tr[rt].son[1]=tr[rt].siz=tr[rt].num=0;
rt=0;
}
inline int New(int val,int fa){
cnt++;
tr[cnt].fa=fa;
tr[cnt].val=val;
tr[cnt].num=tr[cnt].siz=1;
return cnt;
}
inline void Delete(int x){
tr[x].fa=tr[x].son[0]=tr[x].son[1]=tr[x].siz=tr[x].num=0;
}
inline void update(int x){
if(!x) return;
tr[x].siz=tr[x].num;
if(tr[x].son[0]) tr[x].siz+=tr[tr[x].son[0]].siz;
if(tr[x].son[1]) tr[x].siz+=tr[tr[x].son[1]].siz;
}
void rotate(int x){
int y=tr[x].fa,z=tr[y].fa;
int c=(tr[y].son[1]==x);
tr[y].son[c]=tr[x].son[!c];
tr[tr[x].son[!c]].fa=y;
tr[x].son[!c]=y;
if(z) tr[z].son[tr[z].son[1]==y]=x;
tr[y].fa=x;
tr[x].fa=z;
update(y);
update(x);
}
void splay(int x,int goal){
while(tr[x].fa!=goal){
int y=tr[x].fa,z=tr[y].fa;
if(z!=goal) ((tr[y].son[1]==x)^(tr[z].son[1]==y))?rotate(x):rotate(y);
rotate(x);
}
if(!goal) rt=x;
}
bool find(int val){
int x=rt;
while(1){
if(tr[x].val==val) return splay(x,0),1;
if(tr[x].son[tr[x].val<val]) x=tr[x].son[tr[x].val<val];
else return 0;
}
}
int pre(){
int x=tr[rt].son[0];
while(tr[x].son[1]) x=tr[x].son[1];
return x;
}
int nxt(){
int x=tr[rt].son[1];
while(tr[x].son[0]) x=tr[x].son[0];
return x;
}
void insert(int val){
if(!rt){
rt=New(val,0);
return;
}
if(find(val)){
tr[rt].num++;
tr[rt].siz++;
return;
}
int x=rt;
while(1){
if(tr[x].son[tr[x].val<val]) x=tr[x].son[tr[x].val<val];
else{
tr[x].son[tr[x].val<val]=New(val,x);
update(x);
splay(tr[x].son[tr[x].val<val],0);
return;
}
}
}
void del(int val){
find(val);
int x=rt;
tr[rt].num--;
tr[rt].siz--;
if(tr[rt].num) return;
if(!tr[rt].son[0]&&!tr[rt].son[1]){
Init();
return;
}
if(!tr[rt].son[0]){
rt=tr[rt].son[1];
tr[rt].fa=0;
Delete(x);
return;
}
if(!tr[rt].son[1]){
rt=tr[rt].son[0];
tr[rt].fa=0;
Delete(x);
return;
}
find(tr[pre()].val);
splay(x,rt);
if(tr[x].son[1]){
tr[rt].son[1]=tr[x].son[1];
tr[tr[x].son[1]].fa=rt;
}
Delete(x);
}
int getrk(int val){
int x=rt,res=0;
while(1){
if(val==tr[x].val){
res+=tr[x].son[0]?tr[tr[x].son[0]].siz+1:1;
splay(x,0);
return res;
}
if(val<tr[x].val){
if(!tr[x].son[0]) return res;
x=tr[x].son[0];
}else{
res+=tr[x].num;
if(tr[x].son[0]) res+=tr[tr[x].son[0]].siz;
if(!tr[x].son[1]) return res;
x=tr[x].son[1];
}
}
}
int getval(int tot){
int x=rt;
while(1){
if(tr[x].son[0]&&tot<=tr[tr[x].son[0]].siz) x=tr[x].son[0];
else{
tot-=tr[tr[x].son[0]].siz;
if(tot<=tr[x].num) return tr[x].val;
tot-=tr[x].num;
x=tr[x].son[1];
}
}
}
int main(){
read(n);
for(int i=1;i<=n;i++){
int op,x;
read(op);read(x);
if(op==1) insert(x);
if(op==2) del(x);
if(op==3) print(getrk(x)),puts("");
if(op==4) print(getval(x)),puts("");
if(op==5) insert(x),print(tr[pre()].val),puts(""),del(x);
if(op==6) insert(x),print(tr[nxt()].val),puts(""),del(x);
}
return 0;
}