bzoj 3224: Tyvj 1728 普通平衡树 && loj 104 普通平衡树 (splay树)
题目链接:
https://www.lydsy.com/JudgeOnline/problem.php?id=3224
思路:
splay树模板题:
推荐博客:https://blog.csdn.net/clove_unique/article/details/50630280
b站上splay树的讲解视频也可以看下,讲的很好,推荐看完视频了解了splay的原理再写
实现代码:
#include<bits/stdc++.h> using namespace std; #define ll long long const int M = 1e5+10; const int inf = 1<<30; struct node{ int fa,s[2],x,size,cnt; }tr[M]; int top,root,n; void pushup(int k){ tr[k].size = tr[tr[k].s[0]].size+tr[tr[k].s[1]].size+tr[k].cnt; } void rotate(int k){ int x = tr[k].fa,y = tr[x].fa,b=tr[x].s[1]==k; tr[y].s[tr[y].s[1]==x]=k;tr[k].fa=y; tr[x].s[b]=tr[k].s[b^1]; tr[tr[k].s[b^1]].fa=x; tr[k].s[b^1]=x; tr[x].fa = k; pushup(x); pushup(k); } void splay(int x,int go){ //伸展操作 int y = tr[x].fa,z = tr[y].fa; for(;tr[x].fa^go;rotate(x)){ y = tr[x].fa,z = tr[y].fa; if(z^go) (tr[y].s[0] == x)^(tr[z].s[0]==y)?rotate(x):rotate(y); } if(!go) root = x; } void insert(int x){ //插入操作 int u = root,fa = 0; for(;u&&tr[u].x^x;u=tr[u].s[x>tr[u].x]) fa = u; if(u) tr[u].cnt++; //若已经有这个数了,到cnt++即可 else{ //新建一个节点 u = ++top; if(fa) tr[fa].s[x>tr[fa].x] = u; tr[top].fa = fa;tr[top].x = x;tr[top].cnt = tr[top].size = 1; } splay(u,0); //splay 到根 } //这个函数还有一个功能是把一个跟x不相隔仍荷属的数转移到根 int rank(int x){ //查找x的排名 int u = root; if(!u) return 0; for(;tr[u].s[x>tr[u].x]&&x^tr[u].x;u=tr[u].s[x>tr[u].x]); splay(u,0); //splay 到根 return tr[tr[u].s[0]].size; } //求x的前驱后继 int next(int x,int f){ //f=0表示前驱,f=1表示后继,返回的是编号 rank(x); //将与x不相隔任何数的数转移到根 int u = root; if(tr[u].x>x&&f||tr[u].x<x&&!f) return u; //若已经是答案,则返回 u = tr[u].s[f]; while(tr[u].s[f^1]) u = tr[u].s[f^1]; //查找前驱时在左子树中查找最大值,后继反之 return u; } void erase(int x){ int up = next(x,1),lo = next(x,0),u; //得到前驱后继 splay(lo,0); splay(up,lo); //移成一个好局面 u = tr[up].s[0]; if(tr[u].cnt > 1){ tr[u].cnt--; splay(u,0); } else { tr[up].s[0] = 0; pushup(up); pushup(lo); } } int find(int x){ x++; int u = root,son; if(tr[u].size < x) return 0; //没有这个数 while(1){ son = tr[u].s[0]; if(tr[son].size >= x) u = son; else if(x>tr[son].size+tr[u].cnt) x -= tr[son].size + tr[u].cnt, u = tr[u].s[1]; else break; } splay(u,0); return tr[u].x; } int main(){ scanf("%d",&n); int op,x; insert(-inf); insert(inf); for(int i = 1;i <= n;i ++){ scanf("%d%d",&op,&x); switch(op){ case 1: insert(x);break; case 2: erase(x); break; case 3: printf("%d\n",rank(x));break; case 4: printf("%d\n",find(x)); break; case 5: printf("%d\n",tr[next(x,0)].x); break; case 6: printf("%d\n",tr[next(x,1)].x); break; } } return 0; }