算法学习:伸展树(splay)
【定义】
【平衡树】 每个叶子结点的深度差不超过1的二叉树
【伸展树】
【常用问题】
splay的操作,通过左旋右旋,将某个结点通过旋转旋转至根节点,使树的结构发生变化,尽可能的平衡
并且因为左旋右旋的性质,当原树是一个二叉排序树的时候,splay依旧能够使原树保持二叉排序树的性质
左旋右旋图片
【模板题】
【luogu P3369】普通平衡树
【题意】实现一颗二叉排序树的增删查改
【注】对数据结构的理解见注释
【代码】
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXN = 100010;
const int INF = 10000000;
class Splay
{
#define root e[0].ch[1]
private:
class node
{
public:
int v, father;
int ch[2];
int sum;//子树结点个数+自己的结点个数
int recy;//纪录自己被重复多少次
};
node e[MAXN];
int n, points;
public:
void update(int x)//访问左右子树,更新树上所储存的数据
{
e[x].sum = e[e[x].ch[0]].sum + e[e[x].ch[1]].sum + e[x].recy;
}
int identify(int x)
{
return e[e[x].father].ch[0] == x ? 0 : 1;
//自己是不是左孩子
}
void connect(int x, int f, int son)
{
e[x].father = f;
e[f].ch[son] = x;
//将结点x作为结点f的孩子
return;
}
void rotate(int x)
{
//旋转操作的语言描述
//将被指定的节点向上移动一级,并将原有的父级节点作为自己的儿子
//把自己的儿子的位置给自己的爸爸当新儿子
//把自己爸爸放在自己原来的位置
int y = e[x].father;
//自己爸爸
int mroot = e[y].father;
//爸爸的爸爸
int mrootson = identify(y);
//自己爸爸对于爷爷的位置
int yson = identify(x);
//自己对于爸爸的位置
int B = e[x].ch[yson ^ 1];
//自己需要爸爸继承的自己的儿子
connect(B, y, yson);
//把自己需要爸爸带的儿子先给爸爸
connect(y, x, (yson ^ 1));
//把爸爸给自己已经把儿子给爸爸之后空出来的位置给爸爸
connect(x, mroot, mrootson);
//把自己给爷爷
update(y), update(x);
}
void splay(int at, int to)
{
to = e[to].father;
while (e[at].father != to)
{
int up = e[at].father;
if (e[up].father == to) rotate(at);
else if (identify(up) == identify(at))
{
rotate(up);
rotate(at);
}
else
{
rotate(at);
rotate(at);
}
}
}
int crepoint(int v, int father)
{
n++;
e[n].v = v;
e[n].father = father;
e[n].sum = e[n].recy = 1;
return n;
}
void destroy(int x)
{
e[x].v = e[x].ch[0] = e[x].ch[1] = e[x].sum = e[x].father = e[x].recy = 0;
if (x == n) n--;
}
int find(int v)
{
int now = root;
while (true)
{
if (e[now].v == v)
{
splay(now, root);
return now;
}
int next = v < e[now].v ? 0 : 1;
if (!e[now].ch[next]) return 0;
now = e[now].ch[next];
}
}
int build(int v)
{
points++;
if (n == 0)
{
root = 1;
crepoint(v, 0);
}
else
{
int now = root;
while (true)
{
e[now].sum++;
if (v == e[now].v)
{
e[now].recy++;
return now;
}
int next = v < e[now].v ? 0 : 1;
if (!e[now].ch[next])
{
crepoint(v, now);
e[now].ch[next] = n;
return n;
}
now = e[now].ch[next];
}
}
return 0;
}
void push(int v)
{
int add = build(v);
splay(add, root);
}
void pop(int v)//删除结点
{
int deal = find(v);
//找到结点
if (!deal) return;
points--;
if (e[deal].recy > 1)
{
e[deal].recy--;
e[deal].sum--;
return;
}
//直接删除
//去掉这个点
if (!e[deal].ch[0])
{
root = e[deal].ch[1];
e[root].father = 0;
}
else
{
int lef = e[deal].ch[0];
//lef,他的左儿子
while (e[lef].ch[1])
lef = e[lef].ch[1];
//找到他最右的结点,也就是这颗树上最小的值
splay(lef, e[deal].ch[0]);
//将这棵树旋到左结点
int rig = e[deal].ch[1];
connect(rig, lef, 1); connect(lef, 0, 1);
update(lef);
}
destroy(deal);
}
int rank(int v)
{
int ans = 0, now = root;
while (true)
{
if (e[now].v == v)
{
ans = ans + e[e[now].ch[0]].sum + 1;
if (now) splay(now, root);
return ans;
}
if (now == 0) return 0;
if (v < e[now].v) now = e[now].ch[0];
else
{
ans = ans + e[e[now].ch[0]].sum + e[now].recy;
now = e[now].ch[1];
}
}
return 0;
}
int atrank(int x)
{
if (x > points) return -INF;
int now = root;
while (true)
{
int minused = e[now].sum - e[e[now].ch[1]].sum;
//左子树的个数
if (x > e[e[now].ch[0]].sum && x <= minused) break;
//如果这个数在这个范围内
if (x < minused) now = e[now].ch[0];
//如果小于,说明这个数在左子树中
else
{
x = x - minused;
now = e[now].ch[1];
}
//同上
}
splay(now, root);
return e[now].v;
}
int upper(int v)
{
int now = root;
int result = INF;
while (now)
{
if (e[now].v > v && e[now].v < result) result = e[now].v;
if (v < e[now].v)
now = e[now].ch[0];
else
now = e[now].ch[1];
}
return result;
}
int lower(int v)
{
int now = root;
int result = -INF;
while (now)
{
if (e[now].v < v && e[now].v > result) result = e[now].v;
if (v > e[now].v)
now = e[now].ch[1];
else
now = e[now].ch[0];
}
return result;
}
#undef root
};
Splay T;
int main()
{
int n;
scanf("%d", &n);
T.push(INF);
T.push(-INF);
while (n--)
{
int p, v;
scanf("%d%d", &p, &v);
switch (p)
{
case 1:
T.push(v); break;
case 2:
T.pop(v); break;
case 3:
printf("%d\n", T.rank(v) - 1); break;
case 4:
printf("%d\n", T.atrank(v + 1)); break;
case 5:
printf("%d\n", T.lower(v)); break;
case 6:
printf("%d\n", T.upper(v)); break;
default:
break;
}
}
}