浅谈平衡树

什么是平衡树

平衡树其实是二叉搜索树的优化,满足 BST 1 性质。

关于平衡树的种类其实有很多,但本文不涉及太多,我们讲讲最常用的 33 中平衡树吧。

  • TreapTreap
  • fhq Treapfhq \ Treap
  • SplaySplay

先想想二叉搜索树为什么要优化 ?

当然是因为如果我们要插入一连串且非常多的数时,二叉搜索树会被卡成一条链。

而平衡树又是怎么优化的呢 ?

其实是在满足 BST 1 性质时,通过旋转将树拍扁,这样就可以优化时间复杂度。

平衡树模板

要求支持一下操作:

  1. 插入 xx
  2. 删除 xx 数(若有多个相同的数,只删除一个)
  3. 查询 xx 数的排名(排名定义为比当前数小的数的个数 +1+1 )
  4. 查询排名为 xx 的数
  5. xx 的前驱(前驱定义为小于 xx,且最大的数)
  6. xx 的后继(后继定义为大于 xx,且最小的数)

小细节:有时候需要在平衡树中插入正无穷大和负无穷大,以便确定边界。

Treap

TreapTreap,顾名思义就是 Tree(树)+Heap(堆)Tree(树)+Heap(堆) ,说明 TreapTreap 不仅满足 BST 1 性质,还满足堆的性质 2

算法思想

旋转

TreapTreap 的主要思想其实是左旋和右旋。

右旋的操作流程是:假如要对点 yy 进行右旋,且 xxyy 的左儿子,那么 yy 就会成为点 xx 的右儿子。此时若 xx 原本有右儿子,则会发生冲突,所以还要令 xx 原本的右儿子变成 yy 的左儿子,此时仍然满足 BST 1 性质。左旋的操作流程类似,只是把方向调换过来。具体见下图:

在这里插入图片描述

因为我们不知道何时旋转,所以 听天由命\color{SpringGreen}{听天由命} 吧。

即除了二叉搜索树上的权值之外,我们还可以给每个节点加上一个随机的权值。

接下来,我们利用左旋和右旋,将树的形态调整至满足如下性质:

  • 对于原本的权值,这棵树必须满足二叉搜索树的 BST 1 性质,即节点 xx 原本的权值一定大于其左子树中任意一个节点的原本权值且小于其右子树中任意一个节点的原本权值。

  • 对于随机权值,这棵树必须满足堆的性质 2,即节点 xx 的随机权值比其两个子节点的随机权值都大或者都小,通常我们选用大根堆。

所以对于随机数据,Treap 有很优秀的时间复杂度。

void Rotate(int &id, int d)
{
    int temp = ch[id][d ^ 1];
    ch[id][d ^ 1] = ch[temp][d];
    ch[temp][d] = id;
    id = temp;
    pushup(ch[id][d]), pushup(id);
}

更新节点大小

cntcnt 数组记录的是有多少个点与当前点权值相同(包括当前点),因为我这里是将权值一样的节点合到了一个节点上。

void pushup(int id)
{
    siz[id] = siz[ch[id][0]] + siz[ch[id][1]] + cnt[id];
}

新建节点

int New(int v)
{
    val[++tot] = v;
    dat[tot] = rand();
    siz[tot] = 1;
    cnt[tot] = 1;
    return tot;
}

插入

利用 BST 1 性质插入节点,小于插到左边,大于插到右边,没有就新建,然后如果子节点的优先级要大于父节点,我们就把子节点旋转上去。

void insert(int &id, int v)
{
    if (!id)
    {
        id = New(v);
        return;
    }
    if (v == val[id])
        cnt[id]++;
    else
    {
        int d = v < val[id] ? 0 : 1;
        insert(ch[id][d], v);
        if (dat[id] < dat[ch[id][d]])
            Rotate(id, d ^ 1);
    }
    pushup(id);
}

删除

将这个点优先级大的子节点旋转上来,自己就会旋转下去,一直将其旋到叶子,然后删除。

void Remove(int &id, int v)
{
    if (!id)
        return;
    if (v == val[id])
    {
        if (cnt[id] > 1)
        {
            cnt[id]--, pushup(id);
            return;
        }
        if (ch[id][0] || ch[id][1])
        {
            if (!ch[id][1] || dat[ch[id][0]] > dat[ch[id][1]])
            {
                Rotate(id, 1), Remove(ch[id][1], v);
            }
            else
                Rotate(id, 0), Remove(ch[id][0], v);
            pushup(id);
        }
        else
            id = 0;
        return;
    }
    v < val[id] ? Remove(ch[id][0], v) : Remove(ch[id][1], v);
    pushup(id);
}

查找排名

其实也很好理解,如果当前点权值小于要找的数,就去左子树找,否则去右子树找。

int get_rank(int id, int v)
{
    if (!id)
        return 0;
    if (v == val[id])
        return siz[ch[id][0]] + 1;
    else if (v < val[id])
        return get_rank(ch[id][0], v);
    else
        return siz[ch[id][0]] + cnt[id] + get_rank(ch[id][1], v);
}

查找数值

这个就更简单了不是吗?

int get_val(int id, int rank)
{
    if (!id)
        return INF;
    if (rank <= siz[ch[id][0]])
        return get_val(ch[id][0], rank);
    else if (rank <= siz[ch[id][0]] + cnt[id])
        return val[id];
    else
        return get_val(ch[id][1], rank - siz[ch[id][0]] - cnt[id]);
}

找前驱和后继

类似上面的查找

int get_pre(int v)
{
    int id = root, pre;
    while (id)
    {
        if (val[id] < v)
            pre = val[id], id = ch[id][1];
        else
            id = ch[id][0];
    }
    return pre;
}
int get_next(int v)
{
    int id = root, nxt;
    while (id)
    {
        if (val[id] > v)
            nxt = val[id], id = ch[id][0];
        else
            id = ch[id][1];
    }
    return nxt;
}

代码实现

#include <bits/stdc++.h>
using namespace std;

int read()
{
    int out = 0, flag = 1;
    char c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == '-')
            flag = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        out = out * 10 + c - '0';
        c = getchar();
    }
    return flag * out;
}

const int maxn = 1000019, INF = 1e9;
int m;
int ch[maxn][2];
int val[maxn], dat[maxn];
int siz[maxn], cnt[maxn];
int tot, root;
int New(int v)
{
    val[++tot] = v;
    dat[tot] = rand();
    siz[tot] = 1;
    cnt[tot] = 1;
    return tot;
}
void pushup(int id)
{
    siz[id] = siz[ch[id][0]] + siz[ch[id][1]] + cnt[id];
}
void Rotate(int &id, int d)
{
    int temp = ch[id][d ^ 1];
    ch[id][d ^ 1] = ch[temp][d];
    ch[temp][d] = id;
    id = temp;
    pushup(ch[id][d]), pushup(id);
}
void insert(int &id, int v)
{
    if (!id)
    {
        id = New(v);
        return;
    }
    if (v == val[id])
        cnt[id]++;
    else
    {
        int d = v < val[id] ? 0 : 1;
        insert(ch[id][d], v);
        if (dat[id] < dat[ch[id][d]])
            Rotate(id, d ^ 1);
    }
    pushup(id);
}
void Remove(int &id, int v)
{
    if (!id)
        return;
    if (v == val[id])
    {
        if (cnt[id] > 1)
        {
            cnt[id]--, pushup(id);
            return;
        }
        if (ch[id][0] || ch[id][1])
        {
            if (!ch[id][1] || dat[ch[id][0]] > dat[ch[id][1]])
            {
                Rotate(id, 1), Remove(ch[id][1], v);
            }
            else
                Rotate(id, 0), Remove(ch[id][0], v);
            pushup(id);
        }
        else
            id = 0;
        return;
    }
    v < val[id] ? Remove(ch[id][0], v) : Remove(ch[id][1], v);
    pushup(id);
}
int get_rank(int id, int v)
{
    if (!id)
        return 0;
    if (v == val[id])
        return siz[ch[id][0]] + 1;
    else if (v < val[id])
        return get_rank(ch[id][0], v);
    else
        return siz[ch[id][0]] + cnt[id] + get_rank(ch[id][1], v);
}
int get_val(int id, int rank)
{
    if (!id)
        return INF;
    if (rank <= siz[ch[id][0]])
        return get_val(ch[id][0], rank);
    else if (rank <= siz[ch[id][0]] + cnt[id])
        return val[id];
    else
        return get_val(ch[id][1], rank - siz[ch[id][0]] - cnt[id]);
}
int get_pre(int v)
{
    int id = root, pre;
    while (id)
    {
        if (val[id] < v)
            pre = val[id], id = ch[id][1];
        else
            id = ch[id][0];
    }
    return pre;
}
int get_next(int v)
{
    int id = root, nxt;
    while (id)
    {
        if (val[id] > v)
            nxt = val[id], id = ch[id][0];
        else
            id = ch[id][1];
    }
    return nxt;
}
int main()
{
    m = read();
    for (int i = 1; i <= m; i++)
    {
        int cmd = read(), x = read();
        if (cmd == 1)
            insert(root, x);
        else if (cmd == 2)
            Remove(root, x);
        else if (cmd == 3)
            printf("%d\n", get_rank(root, x));
        else if (cmd == 4)
            printf("%d\n", get_val(root, x));
        else if (cmd == 5)
            printf("%d\n", get_pre(x));
        else if (cmd == 6)
            printf("%d\n", get_next(x));
    }
    return 0;
}

fhq Treap

FHQ TreapFHQ \ Treap 好理解,上手快,代码一般很短,支持可持久化,可以实现 TreapTreap 的功能,并且不需要 TreapTreap 的旋转操作,所以 FHQ TreapFHQ \ Treap 又被称为 无旋 TreapTreap 或者 非旋 TreapTreap

算法思想

无旋 TreapTreap 的主要操作有 分裂( splitsplit )和合并( mergemerge )两种。

顾名思义,无旋 TreapTreap 保持树平衡的方式就是不断地将树按照某种方式分裂成两棵子树,再通过合并子树来调整节点的祖孙关系。而分裂又分为 按值分裂 和 按大小分裂 两种,通常情况下,我们会选择按值分裂。

无旋 TreapTreapTreapTreap 一样,给每个节点都附上一个新的随机权值 keykey

同样地,原本的点权满足二叉搜索树的性质,随机权值满足堆的性质。

分裂

SplitSplit 的意思就是将这颗二叉树按某种条件掰开两半。

我们将分裂后左边的树定义为 XX,右边的树定义为 YY,它们的根为 xxyy

若以 valval分裂,则其中以 xx 为根的子树满足所有节点的权值都小于等于 valval,以 yy 为根的子树满足所有节点的权值都大于 valval,这就是按值分裂的规则。

假如一棵树要以 66 来掰开,如图 :

在这里插入图片描述

然后大力一掰 ? ? ?

在这里插入图片描述

具体的算法流程也很简单。

假如 root=0root=0,说明当前子树为空树,无法分裂,所以令 x=y=0x=y=0

rootroot 的权值小于等于 valval,说明 rootroot 应该划分在以 xx 为根的子树内。

因为无旋 TreapTreap 是一棵二叉搜索树,所以 rootroot 的左子树中任意一点的权值 root≤root 的权值 val≤valrootroot 的左子树也应该划分入以 xx 为根的子树。

此时右子树里可能会出现点权比 valval 大的节点,所以在右子树内继续递归分裂。

rootroot 的权值大于 valval 的情况同理,将 rootroot 及其右子树划分入以 yy 为根的子树,继续在 rootroot 的左子树内查找即可。

void split(int rt, int val, int &x, int &y)
{
    if (!rt)
    {
        x = y = 0;
        return;
    }
    if (tr[rt].val <= val)
    {
        x = rt;
        split(tr[rt].r, val, tr[rt].r, y);
    }
    else
    {
        y = rt;
        split(tr[rt].l, val, x, tr[rt].l);
    }
    update(rt);
}

合并

合并操作指将以 xx 为根的子树和以 yy 为根的子树合并成一整棵树,并返回新树根节点的下标。

合并得到的新树满足无旋 TreapTreap 的性质,同时要求以 xx 为根的子树中,所有节点的权值必须小于等于以 yy 为根的子树中任意一点的权值。

假如 xxyy 中存在至少一个 0,那么相当于其中 00 棵或 11 棵子树构成了合并出来的新树,此时直接返回 x+y 即可。

xx 的随机权值大于 yy 的随机权值,说明 xx 必须是 yy 的父节点。

又因为 yy 的点权大于 xx 的点权,所以 yy 必须是 xx 的右儿子,将 xx 的右儿子与以 yy 为根的子树合并即可。

xx 的随机权值小于等于 yy 的随机权值,说明 xx 必须是 yy 的左儿子,将以 xx 为根的子树与 yy 的左儿子合并即可。

int merge(int x, int y)
{
    if (!x || !y)
    {
        return x + y;
    }
    if (tr[x].key > tr[y].key)
    {
        tr[x].r = merge(tr[x].r, y);
        update(x);
        return x;
    }
    else
    {
        tr[y].l = merge(x, tr[y].l);
        update(y);
        return y;
    }
}

更新节点大小

TreapTreap 不同的是,无旋 TreapTreap 中点权相同的节点个数仅统计一次。

void update(int k)
{
    tr[k].size = tr[tr[k].l].size + tr[tr[k].r].size + 1;
}

新建节点

int New(int v)
{
    tr[++cnt].val = v;
    tr[cnt].size = 1;
    tr[cnt].key = rand();
    return cnt;
}

插入

插入一个权值为 valval 的节点,直接将整棵树按 valval 分裂成两棵以 xxyy 为根的子树,令新建节点的下标为 zz ,此时按顺序合并 x,z,y 即可。

void insert(int val)
{
    z = New(val);
    split(root, val, x, y);
    root = merge(merge(x, z), y);
}

删除

直接将整棵树按 valval 分裂两棵以 xzx,z 为根的子树。

再将以 xx 为根的子树按 val1val−1 分裂成两棵以 xyx,y 为根的子树。

此时,以 xx 为根的子树内所有点权均小于等于 val1val−1,也就是说,点权等于 valval 的节点都被划分到了以 yy 为根的子树内。

此时在以 yy 为根的子树内任意删除一个节点(通常选择根节点)即可,具体实现可以直接令以 yy 为根的子树为其左子树和右子树合并得到的树,比原树恰好少了一个根节点。

void del(int val)
{
    split(root, val, x, z);
    split(x, val - 1, x, y);
    y = merge(tr[y].l, tr[y].r);
    root = merge(merge(x, y), z);
}

查询排名

整棵树按 val1val−1 分裂成两棵以 xxyy 为根的子树。

此时以 xx 为根的子树中任意点权小于 valval,所以答案就是以 xx 为根的子树的大小 +1+1

int getrank(int val)
{
    split(root, val - 1, x, y);
    int ret = tr[x].size + 1;
    root = merge(x, y);
    return ret;
}

查询数值

从根节点开始查找,

  • 如果左子树的大小 +1=rk+1=rk ,说明当前节点就是要查找的数值,直接退出;

  • 如果左子树的大小 rk≥rk,说明要查找的数一定在左子树中,在左子树内继续查找;

  • 否则,要查找的树一定是右子树中排名为 rk左子树大小1rk− 左子树大小 −1 的数。

int getval(int rk)
{
    int rt = root;
    while (rt)
    {
        if (tr[tr[rt].l].size + 1 == rk)
        {
            break;
        }
        if (tr[tr[rt].l].size >= rk)
        {
            rt = tr[rt].l;
        }
        else
        {
            rk -= (tr[tr[rt].l].size + 1);
            rt = tr[rt].r;
        }
    }
    return tr[rt].val;
}

查找前驱

将整棵树按 val1val−1 分裂成两棵以 xxyy 为根的子树。

此时以 xx 为根的子树内所有点权一定都小于 valval,查找以 xx 为根的子树内最大的点权即可。

xx 开始,不断地走到右儿子,直到走到叶子节点为止。

int pre(int v)
{
    // 方法1
    split(root, v - 1, x, y);
    int rt = x;
    while (tr[rt].r)
        rt = tr[rt].r;
    root = merge(x, y);
    return tr[rt].val;
    // 方法2
    // return getval(getrank(v) - 1);
}

查找后继

将整棵树按 valval 分裂成两棵以 xxyy 为根的子树。

此时以 yy 为根的子树内所有点权一定都大于 valval,查找以 yy 为根的子树内最小的点权即可。

yy 开始,不断地走到左儿子,直到走到叶子节点为止。

int nxt(int v)
{
    // 方法1
    split(root, v, x, y);
    int rt = y;
    while (tr[rt].l)
        rt = tr[rt].l;
    root = merge(x, y);
    return tr[rt].val;
    // 方法2
    // return getval(getrank(v + 1));
}

代码实现

#include <bits/stdc++.h>
using namespace std;
#define _ (int)1e5 + 7

int n;

int root;

int cnt;

int x, y, z;

struct Tree
{
    int l, r, key, val, size;
} tr[_];

void update(int k)
{
    tr[k].size = tr[tr[k].l].size + tr[tr[k].r].size + 1;
}

int New(int v)
{
    tr[++cnt].val = v;
    tr[cnt].size = 1;
    tr[cnt].key = rand();
    return cnt;
}

void split(int rt, int val, int &x, int &y)
{
    if (!rt)
    {
        x = y = 0;
        return;
    }
    if (tr[rt].val <= val)
    {
        x = rt;
        split(tr[rt].r, val, tr[rt].r, y);
    }
    else
    {
        y = rt;
        split(tr[rt].l, val, x, tr[rt].l);
    }
    update(rt);
}

int merge(int x, int y)
{
    if (!x || !y)
    {
        return x + y;
    }
    if (tr[x].key > tr[y].key)
    {
        tr[x].r = merge(tr[x].r, y);
        update(x);
        return x;
    }
    else
    {
        tr[y].l = merge(x, tr[y].l);
        update(y);
        return y;
    }
}

void insert(int val)
{
    z = New(val);
    split(root, val, x, y);
    root = merge(merge(x, z), y);
}

void del(int val)
{
    split(root, val, x, z);
    split(x, val - 1, x, y);
    y = merge(tr[y].l, tr[y].r);
    root = merge(merge(x, y), z);
}

int getrank(int val)
{
    split(root, val - 1, x, y);
    int ret = tr[x].size + 1;
    root = merge(x, y);
    return ret;
}

int getval(int rk)
{
    int rt = root;
    while (rt)
    {
        if (tr[tr[rt].l].size + 1 == rk)
        {
            break;
        }
        if (tr[tr[rt].l].size >= rk)
        {
            rt = tr[rt].l;
        }
        else
        {
            rk -= (tr[tr[rt].l].size + 1);
            rt = tr[rt].r;
        }
    }
    return tr[rt].val;
}

int pre(int v)
{
    // 方法1
    split(root, v - 1, x, y);
    int rt = x;
    while (tr[rt].r)
        rt = tr[rt].r;
    root = merge(x, y);
    return tr[rt].val;
    // 方法2
    // return getval(getrank(v) - 1);
}

int nxt(int v)
{
    // 方法1
    split(root, v, x, y);
    int rt = y;
    while (tr[rt].l)
        rt = tr[rt].l;
    root = merge(x, y);
    return tr[rt].val;
    // 方法2
    // return getval(getrank(v + 1));
}

signed main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)
    {
        int opt, x;
        scanf("%d%d", &opt, &x);
        if (opt == 1)
        {
            insert(x);
        }
        else if (opt == 2)
        {
            del(x);
        }
        else if (opt == 3)
        {
            printf("%d\n", getrank(x));
        }
        else if (opt == 4)
        {
            printf("%d\n", getval(x));
        }
        else if (opt == 5)
        {
            printf("%d\n", pre(x));
        }
        else
        {
            printf("%d\n", nxt(x));
        }
    }
}

to be continuedto \ be \ continued

Footnotes

  1. 对于树中的任意一个节点,满足 :该节点的关键码不小于它的左子树中任意节点的关键码且不大于它的右子树中任意节点的关键码。 2 3 4 5 6

  2. 堆中每一个节点的值都必须大于等于(或小于等于)其子树中每个节点的值。 2

posted @ 2021-08-22 17:18  蒟蒻orz  阅读(37)  评论(0编辑  收藏  举报  来源