笔记:Splay简约写法

待办:

  • 证明treap、splay的复杂度是logn

疑问:

  • splay的常数是不是有点大?
  • splay删除的具体细节。

splay

存树:

struct node {
	int s[2]; // 左右儿子
	int p; // 父节点
	int v; // 点权
	int cnt; // 重复记录
	int siz; // 子树大小
	void init(int p1, int v1) { // 初始化父亲和权值
		p = p1, v =v1;
		cnt = siz = 1;
	}
}tr[maxn];

左右子树简写:

#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]

向上统计,只需要节点数:

void pushup(int x) {
	tr[x].siz = tr[ls(x)].siz + tr[rs(x)].siz +tr[x].cnt;
}

旋转

旋转:下面的方法非常简洁清新,支持一个函数完成左右旋。下图为右旋示意图,左旋同理的对称操作,用异或可以完成。

image

void rotate(int x)
{
	int y = tr[x].p, z = tr[y].p; // y是x的父,z是x的爷
	int k = tr[y].s[1] == x; // x为y的左子节点k=0,右子节点k=1
	tr[z].s[tr[z].s[1] == y] = x;
	tr[x].p = z;
	
	tr[y].s[k] = tr[x].s[k ^ 1]; // 获取x的一个儿子
	tr[tr[x].s[k ^ 1]].p = y;
	
	tr[x].s[k ^ 1] = y; // x 与 y的关系
	tr[y].p = x;
	pushup(y), pushup(x); // 先后顺序
}

可以用define简写

伸展

伸展:核心操作。分为单旋和双旋,双旋分为直线型和折线形。

(1)y是根,单旋
(2)y不是根,直线型
(3)y不是根,折线形

image

void splay(int x, int k)
{
	while (tr[x].p != k) {
		int y = tr[x].p, z = tr[y].p;
		if (z != k) /*判断单双旋*/{
			(ls(y) == x) ^ (ls(z) == y) ? rotate(x) : rotate(y);//判断直线折现
		}
		rotate(x);
	}
	if (k == 0) // 转根
		root = x;
}

k>0时,把x转到k下面;k=0时,把x转到根

查找:找到v,并移至根

void find(int v) {
	int x = root;
	while (tr[x].s[v > tr[x].v] && v != tr[x].v) {
		x = tr[x].s[v > tr[x].v];
	}
	splay(x, 0); // 找不到会转最近的一个
}

v > tr[x].v用来确定左右子树

前驱、后继

找v的前驱与后继。

先调到树根。前驱往左子树找。后继往右子树找。不存在的点返回最接近这个点的点。

int getpre(int v) {
	find(v);
	int x = root;
	if (tr[x].v < v) return x;
	x = ls(x);
	while (rs(x)) x = rs(x);
	return x;
}
int getnxt(int v){
  find(v);
  int x=root;
  if(tr[x].v>v) return x;
  x=rs(x);
  while(ls(x))x=ls(x);
  return x;
}

删除

若有多个相同的数,只删除1个。前驱后继夹叶子。

void del(int v)
{
	int pre = getpre(v);
	int suc = getnxt(v);
	splay(pre, 0), splay(suc, pre);
	int d = tr[suc].s[0];
	if (tr[d].cnt > 1) {
		tr[d].cnt--, splay(del,0);
	} else {
		tr[suc].s[0] = 0, splay(suc,0);
	}
}

哨兵

设置无限大无限小点,为了删除最大最小点

查值排名 Getrank

十分简单一句话

int get_rank(int v){find(v);return tr[tr[root].s[0]].size;}

查排名值 Getval

查值得时候把排名+1,有个哨兵。

int get_val(int k) {
	int x = root;
	while (true) {
		int y = tr.s[0];
		if (tr[y].size + tr[x].cnt< k) { // 判断去右子树
			k -= tr[y].size + tr[x].cnt;
			x = tr[x].s[1]; // 去右子树
		} else {
			if (tr[y].size >= k) x = tr[x].s[0]; // 左子树 x继续走
			else break; // 左右子树都不能走
		}
	}
	splay(x,0); // must have?
	return tr[x].v;
}

模板题代码

#include <bits/stdc++.h>

using namespace std;

const int maxn = 100010;
const int INF = 1e9;

#define ls(x) tr[x].s[0]
#define rs(x) tr[x].s[1]

int n, m;
struct node
{
    int s[2]; // 左右儿子
    int p;    // 父节点
    int v;    // 点权
    int cnt;  // 重复记录
    int siz;  // 子树大小
    void init(int p1, int v1)
    { // 初始化父亲和权值
        p = p1, v = v1;
        cnt = siz = 1;
    }
} tr[maxn];
int root; // 根节点编号
int idx;  // 节点个数

void pushup(int x)
{
    tr[x].siz = tr[ls(x)].siz + tr[rs(x)].siz + tr[x].cnt;
}
void rotate(int x)
{
    int y = tr[x].p, z = tr[y].p; // y是x的父,z是x的爷
    int k = tr[y].s[1] == x;      // x为y的左子节点k=0,右子节点k=1
    tr[z].s[tr[z].s[1] == y] = x;
    tr[x].p = z;

    tr[y].s[k] = tr[x].s[k ^ 1]; // 获取x的一个儿子
    tr[tr[x].s[k ^ 1]].p = y;

    tr[x].s[k ^ 1] = y; // x 与 y的关系
    tr[y].p = x;
    pushup(y), pushup(x); // 先后顺序
}
void splay(int x, int k)
{
    while (tr[x].p != k)
    {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
        {
            (ls(y) == x) ^ (ls(z) == y) ? rotate(x) : rotate(y);
        }
        rotate(x);
    }
    if (k == 0)
        root = x;
}
void find(int v)
{
    int x = root;
    while (tr[x].s[v > tr[x].v] && v != tr[x].v)
    {
        x = tr[x].s[v > tr[x].v];
    }
    splay(x, 0); // 找不到会转最近的一个
}
int getpre(int v)
{
    find(v);
    int x = root;
    if (tr[x].v < v)
        return x;
    x = ls(x);
    while (rs(x))
        x = rs(x);
    return x;
}
int getnxt(int v)
{ // 后继
    find(v);
    int x = root;
    if (tr[x].v > v)
        return x;
    x = rs(x);
    while (ls(x))
        x = ls(x);
    return x;
}
void del(int v)
{
    int pre = getpre(v);
    int suc = getnxt(v);
    splay(pre, 0), splay(suc, pre);
    int d = tr[suc].s[0];
    if (tr[d].cnt > 1)
    {
        tr[d].cnt--;
        splay(d, 0);
    }
    else
    {
        tr[suc].s[0] = 0;
        splay(suc, 0);
    }
}
int get_rank(int v)
{
    find(v);
    return tr[tr[root].s[0]].siz;
}
int get_val(int k)
{
    int x = root;
    while (true)
    {
        int y = ls(x);
        if (tr[y].siz + tr[x].cnt < k)
        { // 判断去右子树
            k -= tr[y].siz + tr[x].cnt;
            x = tr[x].s[1]; // 去右子树
        }
        else
        {
            if (tr[y].siz >= k)
                x = tr[x].s[0]; // 左子树 x继续走
            else
                break; // 左右子树都不能走
        }
    }
    splay(x, 0); // must have?
    return tr[x].v;
}
void insert(int v)
{
    int x = root, p = 0; // parent p
    while (x && tr[x].v != v)
    {
        p = x, x = tr[x].s[v > tr[x].v];
    }
    if (x)
        tr[x].cnt++;
    else
    {
        x = ++idx;
        tr[p].s[v > tr[p].v] = x;
        tr[x].init(p, v);
    }
    splay(x, 0); // if haven't?
}
int main()
{
    insert(-INF);
    insert(INF); // 哨兵
    scanf("%d", &n);
    while (n--)
    {
        int op, x;
        scanf("%d%d", &op, &x);
        if (op == 1)
            insert(x);
        if (op == 2)
            del(x);
        if (op == 3)
            printf("%d\n", get_rank(x));
        if (op == 4)
            printf("%d\n", get_val(x + 1));
        if (op == 5)
            printf("%d\n", tr[getpre(x)].v);
        if (op == 6)
            printf("%d\n", tr[getnxt(x)].v);
    }
    return 0;
}

posted @ 2022-12-25 16:45  Vegdie  阅读(19)  评论(0编辑  收藏  举报