笔记:Splay简约写法
待办:
- 证明treap、splay的复杂度是logn
疑问:
- splay的常数是不是有点大?
- splay删除的具体细节。
splay
存树:
struct node {
int s[2]; // 左右儿子
int p; // 父节点
int v; // 点权
int cnt; // 重复记录
int siz; // 子树大小
void init(int p1, int v1) { // 初始化父亲和权值
p = p1, v =v1;
cnt = siz = 1;
}
}tr[maxn];
左右子树简写:
#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]
向上统计,只需要节点数:
void pushup(int x) {
tr[x].siz = tr[ls(x)].siz + tr[rs(x)].siz +tr[x].cnt;
}
旋转
旋转:下面的方法非常简洁清新,支持一个函数完成左右旋。下图为右旋示意图,左旋同理的对称操作,用异或可以完成。
void rotate(int x)
{
int y = tr[x].p, z = tr[y].p; // y是x的父,z是x的爷
int k = tr[y].s[1] == x; // x为y的左子节点k=0,右子节点k=1
tr[z].s[tr[z].s[1] == y] = x;
tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1]; // 获取x的一个儿子
tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y; // x 与 y的关系
tr[y].p = x;
pushup(y), pushup(x); // 先后顺序
}
可以用define简写
伸展
伸展:核心操作。分为单旋和双旋,双旋分为直线型和折线形。
(1)y是根,单旋
(2)y不是根,直线型
(3)y不是根,折线形
void splay(int x, int k)
{
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k) /*判断单双旋*/{
(ls(y) == x) ^ (ls(z) == y) ? rotate(x) : rotate(y);//判断直线折现
}
rotate(x);
}
if (k == 0) // 转根
root = x;
}
k>0时,把x转到k下面;k=0时,把x转到根
查找:找到v,并移至根
void find(int v) {
int x = root;
while (tr[x].s[v > tr[x].v] && v != tr[x].v) {
x = tr[x].s[v > tr[x].v];
}
splay(x, 0); // 找不到会转最近的一个
}
v > tr[x].v
用来确定左右子树
前驱、后继
找v的前驱与后继。
先调到树根。前驱往左子树找。后继往右子树找。不存在的点返回最接近这个点的点。
int getpre(int v) {
find(v);
int x = root;
if (tr[x].v < v) return x;
x = ls(x);
while (rs(x)) x = rs(x);
return x;
}
int getnxt(int v){
find(v);
int x=root;
if(tr[x].v>v) return x;
x=rs(x);
while(ls(x))x=ls(x);
return x;
}
删除
若有多个相同的数,只删除1个。前驱后继夹叶子。
void del(int v)
{
int pre = getpre(v);
int suc = getnxt(v);
splay(pre, 0), splay(suc, pre);
int d = tr[suc].s[0];
if (tr[d].cnt > 1) {
tr[d].cnt--, splay(del,0);
} else {
tr[suc].s[0] = 0, splay(suc,0);
}
}
哨兵
设置无限大无限小点,为了删除最大最小点
查值排名 Getrank
十分简单一句话
int get_rank(int v){find(v);return tr[tr[root].s[0]].size;}
查排名值 Getval
查值得时候把排名+1,有个哨兵。
int get_val(int k) {
int x = root;
while (true) {
int y = tr.s[0];
if (tr[y].size + tr[x].cnt< k) { // 判断去右子树
k -= tr[y].size + tr[x].cnt;
x = tr[x].s[1]; // 去右子树
} else {
if (tr[y].size >= k) x = tr[x].s[0]; // 左子树 x继续走
else break; // 左右子树都不能走
}
}
splay(x,0); // must have?
return tr[x].v;
}
模板题代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 100010;
const int INF = 1e9;
#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]
int n, m;
struct node
{
int s[2]; // 左右儿子
int p; // 父节点
int v; // 点权
int cnt; // 重复记录
int siz; // 子树大小
void init(int p1, int v1)
{ // 初始化父亲和权值
p = p1, v = v1;
cnt = siz = 1;
}
} tr[maxn];
int root; // 根节点编号
int idx; // 节点个数
void pushup(int x)
{
tr[x].siz = tr[ls(x)].siz + tr[rs(x)].siz + tr[x].cnt;
}
void rotate(int x)
{
int y = tr[x].p, z = tr[y].p; // y是x的父,z是x的爷
int k = tr[y].s[1] == x; // x为y的左子节点k=0,右子节点k=1
tr[z].s[tr[z].s[1] == y] = x;
tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1]; // 获取x的一个儿子
tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y; // x 与 y的关系
tr[y].p = x;
pushup(y), pushup(x); // 先后顺序
}
void splay(int x, int k)
{
while (tr[x].p != k)
{
int y = tr[x].p, z = tr[y].p;
if (z != k)
{
(ls(y) == x) ^ (ls(z) == y) ? rotate(x) : rotate(y);
}
rotate(x);
}
if (k == 0)
root = x;
}
void find(int v)
{
int x = root;
while (tr[x].s[v > tr[x].v] && v != tr[x].v)
{
x = tr[x].s[v > tr[x].v];
}
splay(x, 0); // 找不到会转最近的一个
}
int getpre(int v)
{
find(v);
int x = root;
if (tr[x].v < v)
return x;
x = ls(x);
while (rs(x))
x = rs(x);
return x;
}
int getnxt(int v)
{ // 后继
find(v);
int x = root;
if (tr[x].v > v)
return x;
x = rs(x);
while (ls(x))
x = ls(x);
return x;
}
void del(int v)
{
int pre = getpre(v);
int suc = getnxt(v);
splay(pre, 0), splay(suc, pre);
int d = tr[suc].s[0];
if (tr[d].cnt > 1)
{
tr[d].cnt--;
splay(d, 0);
}
else
{
tr[suc].s[0] = 0;
splay(suc, 0);
}
}
int get_rank(int v)
{
find(v);
return tr[tr[root].s[0]].siz;
}
int get_val(int k)
{
int x = root;
while (true)
{
int y = ls(x);
if (tr[y].siz + tr[x].cnt < k)
{ // 判断去右子树
k -= tr[y].siz + tr[x].cnt;
x = tr[x].s[1]; // 去右子树
}
else
{
if (tr[y].siz >= k)
x = tr[x].s[0]; // 左子树 x继续走
else
break; // 左右子树都不能走
}
}
splay(x, 0); // must have?
return tr[x].v;
}
void insert(int v)
{
int x = root, p = 0; // parent p
while (x && tr[x].v != v)
{
p = x, x = tr[x].s[v > tr[x].v];
}
if (x)
tr[x].cnt++;
else
{
x = ++idx;
tr[p].s[v > tr[p].v] = x;
tr[x].init(p, v);
}
splay(x, 0); // if haven't?
}
int main()
{
insert(-INF);
insert(INF); // 哨兵
scanf("%d", &n);
while (n--)
{
int op, x;
scanf("%d%d", &op, &x);
if (op == 1)
insert(x);
if (op == 2)
del(x);
if (op == 3)
printf("%d\n", get_rank(x));
if (op == 4)
printf("%d\n", get_val(x + 1));
if (op == 5)
printf("%d\n", tr[getpre(x)].v);
if (op == 6)
printf("%d\n", tr[getnxt(x)].v);
}
return 0;
}