二叉搜索树 & 平衡树

二叉搜索树 & 平衡树 专题

0x00 前言

AFO 了,但不代表不写 Code 了。。。

CSP-S 在数据结构上吃了大亏,就差这一点就一等了,所以觉得好好整整。

本篇博客主要研究二叉搜索树

0x01 Treap

Treap 是一种弱平衡的二叉搜索树。它同时符合二叉搜索树和堆的性质。

本文将多次提及二叉搜索树和堆的性质,所以在这里解释一下这个名词是什么意思。

二叉搜索树的性质

  • 若它的左子树不空,则左子树上所有结点的值均小于它的根结点的值;

  • 若它的右子树不空,则右子树上所有结点的值均大于它的根结点的值;

  • 它的左、右子树也分别为二叉搜索树。

堆的性质是

子节点值比父节点大 / 小。(不要理解错了,整个堆的大小关系是一致的)

接着来谈谈 Treap 怎么维护平衡的。

Treap 维护平衡的方式为旋转,在满足二叉搜索树的条件下根据堆的优先级对 Treap 进行平衡操作。

定义一颗 Treap:

struct Node
{
    int l, r;
    int key, val;
    int cnt, size;
} t[N];

push_up 操作

用于旋转和删除过后,重新计算 size 的值

void push_up(int u)
{
    t[u].size = t[t[u].l].size + t[t[u].r].size + t[u].cnt;
}

旋转

旋转操作的含义:

  • 在不影响搜索树性质的前提下,把和旋转方向相反的子树变成根节点

  • 不影响性质,并且在旋转过后,跟旋转方向相同的子节点变成了原来的根节点

偷个图片方便理解

void zig(int &p)//右旋
{
    int q = t[p].l;
    t[p].l = t[q].r;
    t[q].r = p;
    p = q;
    push_up(t[p].r);
    push_up(p);
}

void zag(int &p)//左旋
{
    int q = t[p].r;
    t[p].r = t[q].l;
    t[q].l = p;
    p = q;
    push_up(t[p].l);
    push_up(p);
}

插入

插的过程中通过旋转来维护树堆中堆的性质。

void insert(int &p, int key)
{
    if (p == 0)
    {// 没这个节点直接新建
        p = get_node(key); 
    }
    else if (t[p].key == key)
    {// 如果有这个值相同的节点,就把重复数量加一
        t[p].cnt++;
    }
    else if (t[p].key > key)
    {// 维护搜索树性质,val 比当前节点小就插到左边,反之亦然
        insert(t[p].l, key);
        if (t[t[p].l].val > t[p].val)
        {
            zig(p);
        }
    }
    else
    {
        insert(t[p].r, key);
        if (t[t[p].r].val > t[p].val)
        {
            zag(p);
        }
    }
    push_up(p);
}

删除

删完了树的大小会有变化,要注意更新。并且如果要删的节点有左子树和右子树,就要考虑删除之后让谁来当父节点。

void del(int &p, int key)
{
    if (p == 0)
    {
        return;
    }
    if (t[p].key == key)
    {
        if (t[p].cnt > 1)
        {
            t[p].cnt--;
        }
        else if (t[p].l || t[p].r)
        {
            if (!t[p].r || t[t[p].l].val > t[t[p].r].val)
            {
                zig(p);
                del(t[p].r, key);
            }
            else
            {
                zag(p);
                del(t[p].l, key);
            }
        }
        else
        {
            p = 0;
        }
    }
    else if (t[p].key > key)
    {
        del(t[p].l, key);
    }
    else
    {
        del(t[p].r, key);
    }
    push_up(p);
}

普通平衡树代码实现

题目链接

Show Code
//Treap
#include "bits/stdc++.h"

#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e5 + 5;
const int inf = 1e9;

int n;
int root, idx;

struct Treap
{
    struct Node
    {
        int l, r;
        int key, val;
        int cnt, size;
    } t[N];

    void push_up(int u)
    {
        t[u].size = t[t[u].l].size + t[t[u].r].size + t[u].cnt;
    }

    int get_node(int key)
    {
        t[++idx].key = key;
        t[idx].val = rand();
        t[idx].cnt = t[idx].size = 1;
        return idx;
    }

    void zig(int &p)
    {
        int q = t[p].l;
        t[p].l = t[q].r;
        t[q].r = p;
        p = q;
        push_up(t[p].r);
        push_up(p);
    }

    void zag(int &p)
    {
        int q = t[p].r;
        t[p].r = t[q].l;
        t[q].l = p;
        p = q;
        push_up(t[p].l);
        push_up(p);
    }

    void insert(int &p, int key)
    {
        if (p == 0) p = get_node(key);
        else if (t[p].key == key) t[p].cnt++;
        else if (t[p].key > key)
        {
            insert(t[p].l, key);
            if (t[t[p].l].val > t[p].val) zig(p);
        }
        else
        {
            insert(t[p].r, key);
            if (t[t[p].r].val > t[p].val) zag(p); 
        }
        push_up(p);
    }

    void del(int &p, int key)
    {
        if (p == 0) return;
        if (t[p].key == key)
        {
            if (t[p].cnt > 1) t[p].cnt--;
            else if (t[p].l || t[p].r)
            {
                if (!t[p].r || t[t[p].l].val > t[t[p].r].val)
                {
                    zig(p), del(t[p].r, key);
                }
                else
                {
                    zag(p), del(t[p].l, key);
                }
            }
            else p = 0;
        }
        else if (t[p].key > key) del(t[p].l, key);
        else del(t[p].r, key);
        push_up(p);
    }

    int find(int p, int key)
    {
        if (p == 0) return 0;
        if (t[p].key == key) return t[t[p].l].size + 1;
        if (t[p].key > key) return find(t[p].l, key);
        return t[t[p].l].size + t[p].cnt + find(t[p].r, key);
    }

    int kth_number(int p, int rank)
    {
        if (p == 0) return inf;
        if (t[t[p].l].size >= rank) return kth_number(t[p].l, rank);
        if (t[t[p].l].size + t[p].cnt >= rank) return t[p].key;
        return kth_number(t[p].r, rank - t[t[p].l].size - t[p].cnt);
    }

    int get_prev(int p, int key)
    {
        if (p == 0) return -inf;
        if (t[p].key >= key) return get_prev(t[p].l, key);
        return max(t[p].key, get_prev(t[p].r, key));
    }

    int get_next(int p, int key)
    {
        if (p == 0) return inf;
        if (t[p].key <= key) return get_next(t[p].r, key);
        return min(t[p].key, get_next(t[p].l, key));
    }
} tree;

signed main()
{
	cin >> n;
    while (n--)
    {
        int op, x;
        cin >> op >> x;
        if (op == 1)
            tree.insert(root, x);
        if (op == 2)
            tree.del(root, x);
        if (op == 3)//不保证查询的数在原序列
            tree.insert(root, x), cout << tree.find(root, x) << endl, tree.del(root, x);
        if (op == 4)
            cout << tree.kth_number(root, x) << endl;
        if (op == 5)
            cout << tree.get_prev(root, x) << endl;
        if (op == 6)
            cout << tree.get_next(root, x) << endl;
    }

    return 0;
}


0x02 Fhq-Treap

Fhq-Treap 就是无旋 Treap。

那为什么有 Treap 了还要学 Fhq-Treap?

为了实现更多的操作,毕竟 Fhq-Treap 是比 Treap 万能多的。(但是好像啊,好像比 Treap 慢一点。。。

而且这个东西实现可持久化非常简单!

分裂

分裂一般来说有两种方法:

  • 1.按照数值分成两颗 Treap,一颗上面的值全部小于等于(或大于等于)key,另一颗相反。

  • 2.按照 size 来分裂

一般来说是采取第一种方法,在此也重点将这一种方法。
那我们怎么分裂呢?
当我们遇到了一个节点,我们判断,假若这个节点的 key<= 需要分裂的 key,那么我们把这个节点以及它的左子树放到a里面,然后往右递归。否则相反。

void split(int p, int &x, int &y, int key)
{
    if (p == 0)
    {
        x = y = 0;
        return;
    }
    if (t[p].key <= key)
    {
        x = p;
        split(t[x].r, t[x].r, y, key);
        push_up(x);
    }
    else
    {
        y = p;
        split(t[y].l, x, t[y].l, key);
        push_up(y);
    }
}

合并

我们假设要合并两个 Treap:ab,且保证 a 的所有节点的值小于 b 所有节点的值

那么每次我们合并的时候就可以先判断一下 a , bval 值的大小,使它满足堆的性质。

void merge(int &p, int x, int y)
{
    if (x == 0 || y == 0)
    {
        p = x | y;
        return;
    }
    if (t[x].val < t[y].val)
    {
        p = x;
        merge(t[p].r, t[p].r, y);
    }
    else
    {
        p = y;
        merge(t[p].l, x, t[p].l);
    }
    push_up(p);
}

普通平衡树代码实现

题目链接

Show Code
//Fhq_Treap
#include "bits/stdc++.h"

#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e5 + 5;

int n;
int idx, root;

struct Fhq_Treap
{
    struct Node
    {
        int l, r;
        int val;
        int key, size;
    } t[N];
    
    int get_node(int key)
    {
        t[++idx].key = key;
        t[idx].size = 1;
        t[idx].val = rand();
        return idx;
    }

    void push_up(int p)
    {
        t[p].size = t[t[p].l].size + t[t[p].r].size + 1;
    }

    void split(int p, int &x, int &y, int key)
    {
        if (p == 0)
        {
            x = y = 0;
            return;
        }
        if (t[p].key <= key)
        {
            x = p;
            split(t[x].r, t[x].r, y, key);
            push_up(x);
        }
        else
        {
            y = p;
            split(t[y].l, x, t[y].l, key);
            push_up(y);
        }
    }

    void merge(int &p, int x, int y)
    {
        if (x == 0 || y == 0)
        {
            p = x | y;
            return;
        }
        if (t[x].val < t[y].val) p = x, merge(t[p].r, t[p].r, y);
        else p = y, merge(t[p].l, x, t[p].l);
        push_up(p);
    }

    void insert(int &p, int key)
    {
        int r1 = 0, r2 = 0, r3 = get_node(key);
        split(p, r1, r2, key);
        merge(r1, r1, r3);
        merge(p, r1, r2);
    }

    void del(int &p, int key)
    {
        int r1 = 0, r2 = 0, r3 = 0;
        split(p, r1, r2, key);
        split(r1, r1, r3, key - 1);
        merge(r3, t[r3].l, t[r3].r);
        merge(r1, r1, r3);
        merge(p, r1, r2);
    }

    int find(int &p, int key)
    {
        int r1 = 0, r2 = 0;
        split(p, r1, r2, key - 1);
        int res = t[r1].size + 1;
        merge(p, r1, r2);
        return res;
    }

    int kth_number(int p, int rank)
    {
        while (t[t[p].l].size + 1 != rank)
        {
            if (t[t[p].l].size >= rank) p = t[p].l;
            else rank -= t[t[p].l].size + 1, p = t[p].r;
        }
        return t[p].key;
    }

    int get_prev(int &p, int key)
    {
        int r1 = 0, r2 = 0;
        split(p, r1, r2, key - 1);
        int res = kth_number(r1, t[r1].size);
        merge(p, r1, r2);
        return res;
    }

    int get_next(int &p, int key)
    {
        int r1 = 0, r2 = 0;
        split(p, r1, r2, key);
        int res = kth_number(r2, 1);
        merge(p, r1, r2);
        return res;
    }
} tree;

signed main()
{
    cin >> n;
    while (n--)
    {
        int op, x;
        cin >> op >> x;
        if (op == 1)
            tree.insert(root, x);
        if (op == 2)
            tree.del(root, x);
        if (op == 3)
            tree.insert(root, x), cout << tree.find(root, x) << endl, tree.del(root, x);;
        if (op == 4)
            cout << tree.kth_number(root, x) << endl;
        if (op == 5)
            cout << tree.get_prev(root, x) << endl;
        if (op == 6)
            cout << tree.get_next(root, x) << endl;
    }

    return 0;
}


可持久化实现

题目链接

Show Code
//persistable_Fhq_Treap
#include "bits/stdc++.h"

#define rint register int
#define endl '\n'

using namespace std;

const int N = 5e5 + 5;

int idx, rt[N];

struct persistable_Fhq_Treap
{
    struct node
    {
        int l, r;
        int val;
        int key, size;
    } t[N * 50];

    int get_node(int key)
    {
        t[++idx].key = key;
        t[idx].size = 1;
        t[idx].val = rand();
        return idx;
    }

    void push_up(int p)
    {
        t[p].size = t[t[p].l].size + t[t[p].r].size + 1;
    }

    void split(int p, int &x, int &y, int key)
    {
        if (p == 0)
        {
            x = y = 0;
            return;
        }
        if (t[p].key <= key)
        {
            x = ++idx;
            t[x] = t[p];
            split(t[x].r, t[x].r, y, key);
            push_up(x);
        }
        else
        {
            y = ++idx;
            t[y] = t[p];
            split(t[y].l, x, t[y].l, key);
            push_up(y);
        }
    }

    void merge(int &p, int x, int y)
    {
        if (x == 0 || y == 0)
        {
            p = x | y;
            return;
        }
        if (t[x].val > t[y].val)
        {
            int k = ++idx;
            t[k] = t[x];
            p = x;
            merge(t[p].r, t[p].r, y);
        }
        else
        {
            int k = ++idx;
            t[k] = t[y];
            p = y;
            merge(t[p].l, x, t[p].l);
        }
        push_up(p);
    }

    void insert(int &p, int key)
    {
        int r1 = 0, r2 = 0, r3 = get_node(key);
        split(p, r1, r2, key);
        merge(r1, r1, r3);
        merge(p, r1, r2);
    }

    void del(int &p, int key)
    {
        int r1 = 0, r2 = 0, r3 = 0;
        split(p, r1, r2, key);
        split(r1, r1, r3, key - 1);
        merge(r3, t[r3].l, t[r3].r);
        merge(r1, r1, r3);
        merge(p, r1, r2);
    }

    int find(int &p, int key)
    {
        int r1 = 0, r2 = 0;
        split(p, r1, r2, key - 1);
        int res = t[r1].size + 1;
        merge(p, r1, r2);
        return res;
    }

    int kth_number(int p, int rank)
    {
        while (t[t[p].l].size + 1 != rank)
        {
            if (t[t[p].l].size >= rank) p = t[p].l;
            else rank -= t[t[p].l].size + 1, p = t[p].r;
        }
        return t[p].key;
    }

    int get_prev(int &p, int key)
    {
        int r1 = 0, r2 = 0;
        split(p, r1, r2, key - 1);
        int res = kth_number(r1, t[r1].size);
        merge(p, r1, r2);
        return res;
    }

    int get_next(int &p, int key)
    {
        int r1 = 0, r2 = 0;
        split(p, r1, r2, key);
        int res = kth_number(r2, 1);
        merge(p, r1, r2);
        return res;
    }
} tree;

signed main()
{
    int T;
    cin >> T;
    int cnt = 1;
    while (T--)
    {
        int times, op, x;
        cin >> times >> op >> x;
        rt[cnt] = rt[times];
        if (op == 1)
            tree.insert(rt[cnt], x);
        if (op == 2)
            tree.del(rt[cnt], x);
        if (op == 3)
            cout << tree.find(rt[cnt], x) << endl;
        if (op == 4)
            cout << tree.kth_number(rt[cnt], x) << endl;
        if (op == 5)
            cout << tree.get_prev(rt[cnt], x) << endl;
        if (op == 6)
            cout << tree.get_next(rt[cnt], x) << endl;
        cnt++;
    }

    return 0;
}


文艺平衡树

题目链接

这个题其实就是一个区间修改操作,加个 push_down 就好了,加个 tag 标记此点的左右儿子是否需要交换。

Show Code
//Fhq_ArtTreap
#include "iostream"
#include "cstdio"
#include "algorithm"

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int idx, root;

struct Fhq_ArtTreap
{
    struct node
    {
        int l, r;
        int val;
        int size, key;
        bool tag;
    } t[N];

    int get_node(int key)
    {
        t[++idx].key = key;
        t[idx].size = 1;
        t[idx].val = rand();
        t[idx].tag = 0;
        return idx;
    }

    void push_up(int p)
    {
        t[p].size = t[t[p].l].size + t[t[p].r].size + 1;
    }

    void push_down(int p)
    {
        if (t[p].tag)
        {
            t[p].tag = 0;
            t[t[p].l].tag ^= 1;
            t[t[p].r].tag ^= 1;
            std::swap(t[p].l, t[p].r);
        }
    }

    void split(int p, int &x, int &y, int key)
    {
        if (p == 0)
        {
            x = y = 0;
            return;
        }
        push_down(p);
        if (t[t[p].l].size + 1 <= key)
        {
            x = p;
            split(t[x].r, t[x].r, y, key - t[t[x].l].size - 1);
            push_up(x);
        }
        else
        {
            y = p;
            split(t[y].l, x, t[y].l, key);
            push_up(y);
        }
    }

    void merge(int &p, int x, int y)
    {
        if (x == 0 || y == 0)
        {
            p = x ^ y;
            return;
        }
        if (t[x].val < t[y].val)
        {
            push_down(y);
            p = y;
            merge(t[p].l, x, t[p].l);
        }
        else
        {
            push_down(x);
            p = x;
            merge(t[p].r, t[p].r, y);
        }
        push_up(p);
    }

    void reverse(int &p, int x, int y)
    {
        int r1 = 0, r2 = 0, r3 = 0;
        split(p, r1, r2, x - 1);
        split(r2, r3, r2, y - x + 1);
        t[r3].tag ^= 1;
        merge(r2, r3, r2);
        merge(p, r1, r2);
    }

    void build(int &p, int x, int y)
    {
        if (x > y)
        {
            p = 0;
            return;
        }
        if (x == y)
        {
            p = get_node(x);
            return;
        }
        int mid = (x + y) >> 1;
        p = get_node(mid);
        build(t[p].l, x, mid - 1);
        build(t[p].r, mid + 1, y);
        push_up(p);
        t[p].val = std::max(t[t[p].l].val, t[t[p].r].val) + 1;
    }

    void dfs_print(int p)
    {
        if (p == 0)
        {
            return;
        }
        push_down(p);
        dfs_print(t[p].l);
        printf("%d ", t[p].key);
        dfs_print(t[p].r);
    }
} tree;

int n, m;

int main()
{
    scanf("%d%d", &n, &m);
    tree.build(root, 1, n);
    while (m--)
    {
        int l, r;
        scanf("%d%d", &l, &r);
        tree.reverse(root, l, r);
    }
    tree.dfs_print(root);
    return 0;
}
  


文艺平衡树可持久化实现

题目链接

不只有翻转操作了,需要多修改几个地方,主要思路不变。

Show Code
//persistable_Fhq_ArtTreap
#include "iostream"
#include "cstdio"
#include "algorithm"

#define rint register int
#define endl '\n'

const int N = 5e5 + 5;

int idx, rt[N];

struct persistable_Fhq_ArtTreap
{
    struct node
    {
        int l, r;
        int val;
        int size, key;
        bool tag;
        long long sum;
    } t[N * 50];

    int get_node(int key)
    {
        t[++idx].key = key;
        t[idx].sum = key;
        t[idx].size = 1;
        t[idx].val = rand();
        t[idx].tag = 0;
        return idx;
    }

    int copy_node(int p)
    {
        t[++idx].key = t[p].key;
        t[idx].sum = t[p].sum;
        t[idx].size = t[p].size;
        t[idx].val = t[p].val;
        t[idx].tag = t[p].tag;
        t[idx].l = t[p].l;
        t[idx].r = t[p].r;
        return idx;
    }

    void push_up(int p)
    {
        t[p].size = t[t[p].l].size + t[t[p].r].size + 1;
        t[p].sum = t[t[p].l].sum + t[t[p].r].sum + t[p].key;
    }

    void push_down(int p)
    {
        if (t[p].tag)
        {
            if (t[p].l)
            {
                t[p].l = copy_node(t[p].l);
                t[t[p].l].tag ^= 1;
            }
            if (t[p].r)
            {
                t[p].r = copy_node(t[p].r);
                t[t[p].r].tag ^= 1;
            }
            std::swap(t[p].l, t[p].r);
            t[p].tag = 0;
        }
    }

    void split(int p, int &x, int &y, int key)
    {
        if (p == 0)
        {
            x = y = 0;
            return;
        }
        int k = copy_node(p);
        push_down(k);
        if (t[t[k].l].size + 1 <= key)
        {
            x = k;
            split(t[x].r, t[x].r, y, key - t[t[x].l].size - 1);
            push_up(x);
        }
        else
        {
            y = k;
            split(t[y].l, x, t[y].l, key);
            push_up(y);
        }
    }

    void merge(int &p, int x, int y)
    {
        if (x == 0 || y == 0)
        {
            p = x ^ y;
            return;
        }
        if (t[x].val < t[y].val)
        {
            p = y;
            push_down(y);
            merge(t[p].l, x, t[p].l);
        }
        else
        {
            p = x;
            push_down(x);
            merge(t[p].r, t[p].r, y);
        }
        push_up(p);
    }

    void insert(int &p, int rank, int key)
    {
        int r1 = 0, r2 = 0;
        split(p, r1, r2, rank);
        merge(r2, get_node(key), r2);
        merge(p, r1, r2);
    }

    void del(int &p, int key)
    {
        int r1 = 0, r2 = 0, r3 = 0;
        split(p, r1, r2, key - 1);
        split(r2, r3, r2, 1);
        merge(p, r1, r2);
    }

    void reverse(int &p, int x, int y)
    {
        int r1 = 0, r2 = 0, r3 = 0;
        split(p, r1, r2, x - 1);
        split(r2, r3, r2, y - x + 1);
        t[r3].tag ^= 1;
        merge(r2, r3, r2);
        merge(p, r1, r2);
    }

    long long query(int &p, int x, int y)
    {
        int r1 = 0, r2 = 0, r3 = 0;
        split(p, r1, r2, x - 1);
        split(r2, r3, r2, y - x + 1);
        long long ans = t[r3].sum;
        merge(r2, r3, r2);
        merge(p, r1, r2);
        return ans;
    }
} tree;

int n;

int main()
{
    scanf("%d", &n);
    int cnt = 1;
    long long lastans = 0;
    while (n--)
    {
        int times, op;
        long long x, y;
        scanf("%d%d%lld", ×, &op, &x);
        x ^= lastans;
        if (op != 2)
        {
            scanf("%lld", &y);
            y ^= lastans;
        }
        rt[cnt] = rt[times];
        if (op == 1)
        {
            tree.insert(rt[cnt], x, y);
        }
        if (op == 2)
        {
            tree.del(rt[cnt], x);
        }
        if (op == 3)
        {
            tree.reverse(rt[cnt], x, y);
        }
        if (op == 4)
        {
            printf("%lld\n", lastans = tree.query(rt[cnt], x, y));
        }
        cnt++;
    }

    return 0;
}
  


0x03 替罪羊树

我们已经会了 Treap 和 Fhq-Treap,其实只需要熟练运用 Fhq-Treap 即可。

但之所以还要学替罪羊,是因为本文后面会讲到后缀平衡树且后面学 K-D Tree 是需要用到,所以在这里我们只需要理解这个数据结构的精髓并能解决模板题即可。

替罪羊树最大的特点就是暴力。

替罪羊树会将不平衡的子树进行重构来保证其平衡。

而其判断子树平衡与否就是根据刚才讲的平衡因数 alpha,只不过这里是人为设定的,称之为平衡常数。

节点

struct Node
{
    int l, r;
    int val, cnt;
    int s;//子树内节点个数。
    int sz;//子树内数据个数。
    int sd;//子树内不计删除节点的节点个数。
} t[N];

重构

替罪羊树之所以能够平衡,是在于其重构时不是瞎重构,而是将被重构的子树重构为一棵完全二叉树。

当然我们都知道这样费时又费力,更何况还是暴力重构的。

所以我们认为设定的平衡常数 alpha 在此时就起到了决定性的作用。

具体如何暴力重构就不用太多赘述了,我们可以使用简单的方法来保证线性建树,然后将新建的树接过来即可。

我们重构分两种情况:一是子树不平衡了,即左右子树之一的大小占其本身子树大小的比例超过 alpha;二是被删除的节点太多了,这样也会影响效率。

具体操作就是首先我们将需要重构的子树经中序遍历展开之后存入数组中,然后将新得到的数组二分建树。

void unfold(int &idx, int p)
{
    if (p == 0)
    {
        return;
    }
    unfold(idx, t[p].l);
    if (t[p].cnt)
    {
        q[idx++] = p;
    }
    unfold(idx, t[p].r);
}

int build(int l, int r)
{
    if (l >= r)
    {
        return 0;
    }
    int mid = (l + r) >> 1;
    t[q[mid]].l = build(l, mid);
    t[q[mid]].r = build(mid + 1, r);
    push_up(q[mid]);
    return q[mid];
}

void rebuild(int &p)
{
    int k = 0;
    unfold(k, p);
    p = build(0, k);
}

检查是否需要重构

bool check(int p)
{
    return t[p].cnt && (alpha * t[p].s <= (double)std::max(t[t[p].l].s, t[t[p].r].s) || (double)t[p].sd <= alpha * t[p].s);
}

插入

插入时,我们需要找到对应节点并 t[p].cnt++,如果没有节点就新建一个,回溯时需要判断是否能够重构,如果可以的话就重构。

void insert(int &p, int val)
{
    if (!p)
    {
        p = ++cnt;
        if (!root)
        {
            root = 1;
        }
        t[p].val = val;
        t[p].l = t[p].r = 0;
        t[p].cnt = t[p].s = t[p].sz = t[p].sd = 1;
    }
    else
    {
        if (t[p].val == val)
        {
            t[p].cnt++;
        }
        else if (t[p].val < val)
        {
            insert(t[p].r, val);
        }
        else
        {
            insert(t[p].l, val);
        }
        push_up(p);
        if (check(p))
        {
            rebuild(p);
        }
    }
}

删除

替罪羊树使用惰性删除,找到对应节点之后只需要 t[p].cnt-- 即可。当然,回溯时候遇到可以重构的节点时要重构。

void del(int &p, int val)
{
    if (p == 0)
    {
        return;
    }
    if (t[p].val == val)
    {
        if (t[p].cnt)
        {
            t[p].cnt--;
        }
    }
    else
    {
        if (t[p].val < val)
        {
            del(t[p].r, val);
        }
        else
        {
            del(t[p].l, val);
        }
    }
    push_up(p);
    if (check(p))
    {
        rebuild(p);
    }
}

upper_grade 操作

返回第一个大于其权值的位置。在查后继时用到。

int upper_grade(int p, int val)
{
    if (p == 0)
    {
        return 1;
    }
    if (t[p].val == val && t[p].cnt)
    {
        return t[t[p].l].sz + t[p].cnt + 1;
    }
    if (t[p].val > val)
    {
        return upper_grade(t[p].l, val);
    }
    return t[t[p].l].sz + t[p].cnt + upper_grade(t[p].r, val);
}

普通平衡树代码实现

题目链接

Show Code
//Scapegoat
#include "iostream"
#include "cstdio"
#include "algorithm"

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

double alpha = 0.75;

int n, root;

struct Scapegoat
{
    struct Node
    {
        int l, r;
        int val, cnt;
        int s, sz, sd;
    } t[N];

    int cnt;
    int q[N];

    void push_up(int p)
    {
        t[p].s = t[t[p].l].s + t[t[p].r].s + 1;
        t[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
        t[p].sd = t[t[p].l].sd + t[t[p].r].sd + (t[p].cnt != 0);
    }

    bool check(int p)
    {
        return t[p].cnt && (alpha * t[p].s <= (double)std::max(t[t[p].l].s, t[t[p].r].s) || (double)t[p].sd <= alpha * t[p].s);
    }

    void unfold(int &idx, int p)
    {
        if (p == 0)
        {
            return;
        }
        unfold(idx, t[p].l);
        if (t[p].cnt)
        {
            q[idx++] = p;
        }
        unfold(idx, t[p].r);
    }

    int build(int l, int r)
    {
        if (l >= r)
        {
            return 0;
        }
        int mid = (l + r) >> 1;
        t[q[mid]].l = build(l, mid);
        t[q[mid]].r = build(mid + 1, r);
        push_up(q[mid]);
        return q[mid];
    }

    void rebuild(int &p)
    {
        int k = 0;
        unfold(k, p);
        p = build(0, k);
    }

    void insert(int &p, int val)
    {
        if (!p)
        {
            p = ++cnt;
            if (!root)
            {
                root = 1;
            }
            t[p].val = val;
            t[p].l = t[p].r = 0;
            t[p].cnt = t[p].s = t[p].sz = t[p].sd = 1;
        }
        else
        {
            if (t[p].val == val)
            {
                t[p].cnt++;
            }
            else if (t[p].val < val)
            {
                insert(t[p].r, val);
            }
            else
            {
                insert(t[p].l, val);
            }
            push_up(p);
            if (check(p))
            {
                rebuild(p);
            }
        }
    }

    void del(int &p, int val)
    {
        if (p == 0)
        {
            return;
        }
        if (t[p].val == val)
        {
            if (t[p].cnt)
            {
                t[p].cnt--;
            }
        }
        else
        {
            if (t[p].val < val)
            {
                del(t[p].r, val);
            }
            else
            {
                del(t[p].l, val);
            }
        }
        push_up(p);
        if (check(p))
        {
            rebuild(p);
        }
    }

    int upper_grade(int p, int val)
    {
        if (p == 0)
        {
            return 1;
        }
        if (t[p].val == val && t[p].cnt)
        {
            return t[t[p].l].sz + t[p].cnt + 1;
        }
        if (t[p].val > val)
        {
            return upper_grade(t[p].l, val);
        }
        return t[t[p].l].sz + t[p].cnt + upper_grade(t[p].r, val);
    }

    int find(int p, int val)
    {
        if (p == 0)
        {
            return 0;
        }
        if (t[p].val == val && t[p].cnt)
        {
            return t[t[p].l].sz;
        }
        if (t[p].val < val)
        {
            return t[t[p].l].sz + t[p].cnt + find(t[p].r, val);
        }
        return find(t[p].l, val);
    }

    int kth_number(int p, int val)
    {
        if (p == 0)
        {
            return 0;
        }
        if (t[t[p].l].sz < val && val <= t[t[p].l].sz + t[p].cnt)
        {
            return t[p].val;
        }
        if (t[t[p].l].sz + t[p].cnt < val)
        {
            return kth_number(t[p].r, val - t[t[p].l].sz - t[p].cnt);
        }
        return kth_number(t[p].l, val);
    }

    int get_prev(int p, int val)
    {
        return kth_number(p, find(p, val));
    }

    int get_next(int p, int val)
    {
        return kth_number(p, upper_grade(p, val));
    }
} tree;

int main()
{
    scanf("%d", &n);
    while (n--)
    {
        int op, x;
        scanf("%d%d", &op, &x);
        if (op == 1)
        {
            tree.insert(root, x);
        }
        if (op == 2)
        {
            tree.del(root, x);
        }
        if (op == 3)
        {
            printf("%d\n", tree.find(root, x) + 1);
        }
        if (op == 4)
        {
            printf("%d\n", tree.kth_number(root, x));
        }
        if (op == 5)
        {
            printf("%d\n", tree.get_prev(root, x));
        }
        if (op == 6)
        {
            printf("%d\n", tree.get_next(root, x));
        }
    }

    return 0;
}
  


0x04 Splay

Splay 树, 是一种平衡二叉搜索树,它通过伸展操作不断将某个节点旋转到根节点,使得整棵树仍然满足二叉查找树的性质,能够 $log $完成插入,查找和删除操作,并且保持平衡而不至于退化为链。

旋转

为了使 Splay 保持平衡而进行旋转操作,旋转的本质是将某个节点上移一个位置。

旋转需要保证:

  • 整棵 Splay 的中序遍历不变(不能破坏二叉查找树的性质)。
  • 受影响的节点维护的信息依然正确有效。
  • root 必须指向旋转后的根节点。

在 Splay 中旋转分为两种:左旋和右旋。

具体分析旋转步骤(假设需要旋转的节点为 \(x\),其父亲为 \(y\),以右旋为例)

  • \(y\) 的左儿子指向 \(x\) 的右儿子,且 \(x\) 的右儿子(如果 \(x\) 有右儿子的话)的父亲指向 \(y\)

  • \(x\) 的右儿子指向 \(y\) ,且 \(y\) 的父亲指向 \(x\)

  • 如果原来的 \(y\) 还有父亲 \(z\) ,那么把 的某个儿子(原来 \(y\) 所在的儿子位置)指向 \(x\) ,且 \(x\) 的父亲指向 \(z\) ;

void rotate(int x)
{
    int y = t[x].fa;
    int z = t[y].fa;
    int k = t[y].s[1] == x;
    t[z].s[t[z].s[1] == y] = x;
    t[x].fa = z;
    t[y].s[k] = t[x].s[k ^ 1];
    t[t[x].s[k ^ 1]].fa = y;
    t[x].s[k ^ 1] = y;
    t[y].fa = x;
    push_up(y);
    push_up(x);
}

延伸

Splay 操作规定:每访问一个节点 \(x\) 后都要强制将其旋转到根节点。

Splay 操作即对 \(x\) 做一系列的 splay 步骤。每次对 \(x\) 做一次 splay 步骤, \(x\) 到根节点的距离都会更近。

void splay(int x, int k)
{
    while (t[x].fa != k)
    {
        int y = t[x].fa;
        int z = t[y].fa;
        if (z != k)
        {
            if ((t[y].s[1] == x) ^ (t[z].s[1] == y))
            {
                rotate(x);
            }
            else
            {
                rotate(y);
            }
        }
        rotate(x);
    }
    if (k == 0)
    {
        root = x;
    }
}

插入

插入操作是一个比较复杂的过程,具体步骤如下

假设插入的值为 \(k\)

  • 如果树空了,则直接插入根并退出。

  • 如果当前节点的权值等于 \(k\) 则增加当前节点的大小并更新节点和父亲的信息,将当前节点进行 Splay 操作。

  • 否则按照二叉查找树的性质向下找,找到空节点就插入即可

void insert(int k)
{
    if (root == 0)
    {
        t[++idx].init(k, 0);
        root = idx;
        return;
    }
    int p = root, fa = 0;
    while (1)
    {
        if (t[p].val == k)
        {
            t[p].cnt++;
            push_up(p);
            push_up(fa);
            splay(p, 0);
            break;
        }
        fa = p;
        p = t[p].s[t[p].val < k];
        if (p == 0)
        {
            t[++idx].init(k, fa);
            t[fa].s[t[fa].val < k] = idx;
            push_up(fa);
            splay(idx, 0);
            break;
        }
    }
}

普通平衡树代码实现

题目链接

Show Code
//Splay
#include "iostream"
#include "cstdio"
#include "algorithm"

#define rint register int
#define endl '\n'

const int N = 5e5 + 5;

int root, idx;

struct Splay
{
    struct Node
    {
        int s[2], fa;
        int val, cnt;
        int size; //儿子数量 
        void init(int _val, int _fa)
        {
            val = _val, fa = _fa;
            size = cnt = 1;
        }
        void clear()
        {
            s[0] = s[1] = 0;
            fa = 0;
            val = cnt = 0;
            size = 0;
        }
    } t[N];
    
    void push_up(int p)
    {
        t[p].size = t[p].cnt;
        if (t[p].s[0])
            t[p].size += t[t[p].s[0]].size;
        if (t[p].s[1])
            t[p].size += t[t[p].s[1]].size;
    }

    void rotate(int x)
    {
        int y = t[x].fa;
        int z = t[y].fa;
        int k = t[y].s[1] == x;
        t[z].s[t[z].s[1] == y] = x;
        t[x].fa = z;
        t[y].s[k] = t[x].s[k ^ 1];
        t[t[x].s[k ^ 1]].fa = y;
        t[x].s[k ^ 1] = y;
        t[y].fa = x;
        push_up(y);
        push_up(x);
    }

    void splay(int x, int k)
    {
        while (t[x].fa != k)
        {
            int y = t[x].fa;
            int z = t[y].fa;
            if (z != k)
            {
                if ((t[y].s[1] == x) ^ (t[z].s[1] == y))
                {
                    rotate(x);
                }
                else
                {
                    rotate(y);
                }
            }
            rotate(x);
        }
        if (k == 0)
        {
            root = x;
        }
    }

    void insert(int k)
    {
        if (root == 0)
        {
            t[++idx].init(k, 0);
            root = idx;
            return;
        }
        int p = root, fa = 0;
        while (1)
        {
            if (t[p].val == k)
            {
                t[p].cnt++;
                push_up(p);
                push_up(fa);
                splay(p, 0);
                break;
            }
            fa = p;
            p = t[p].s[t[p].val < k];
            if (p == 0)
            {
                t[++idx].init(k, fa);
                t[fa].s[t[fa].val < k] = idx;
                push_up(fa);
                splay(idx, 0);
                break;
            }
        }
    }

    int find(int k)
    {
        int res = 0, p = root;
        while (1)
        {
            if (k < t[p].val)
            {
                p = t[p].s[0];
            }
            else
            {
                if (t[p].s[0])
                {
                    res += t[t[p].s[0]].size;
                }
                if (k == t[p].val)
                {
                    splay(p, 0);
                    return res + 1;
                }
                res += t[p].cnt;
                p = t[p].s[1];
            }
        }
    }

    int kth_number(int k)
    {
        int p = root;
        while (1)
        {
            if (t[t[p].s[0]].size >= k && t[p].s[0])
            {
                p = t[p].s[0];
            }
            else
            {
                k -= t[p].cnt;
                if (t[p].s[0])
                {
                    k -= t[t[p].s[0]].size;
                }
                if (k <= 0)
                {
                    splay(p, 0);
                    return t[p].val;
                }
                p = t[p].s[1];
            }
        }
    }

    int prev()
    {
        int p = t[root].s[0];
        if (p == 0)
        {
            return p;
        }
        while (t[p].s[1])
        {
            p = t[p].s[1];
        }
        splay(p, 0);
        return p;
    }

    int next()
    {
        int p = t[root].s[1];
        if (p == 0)
        {
            return p;
        }
        while (t[p].s[0])
        {
            p = t[p].s[0];
        }
        splay(p, 0);
        return p;
    }
    
    int get_prev()
    {
	return t[prev()].val;
    }
	
    int get_next()
    {
	return t[next()].val;
    }

    void del(int k)
    {
        find(k);
        if (t[root].cnt > 1)
        {
            t[root].cnt--;
            push_up(root);
            return;
        }
        if (!t[root].s[0] && !t[root].s[1])
        {
            t[root].clear();
            root = 0;
            return;
        }
        if (!t[root].s[0])
        {
            int p = root;
            root = t[root].s[1];
            t[root].fa = 0;
            t[p].clear();
            return;
        }
        if (!t[root].s[1])
        {
            int p = root;
            root = t[root].s[0];
            t[root].fa = 0;
            t[p].clear();
            return;
        }
        int p = root;
        int x = prev();
        t[t[p].s[1]].fa = x;
        t[x].s[1] = t[p].s[1];
        t[p].clear();
        push_up(root);
    }
} tree;

int n;

int main()
{
	scanf("%d", &n);
    while (n--)
    {
        int op, x;
        scanf("%d%d", &op, &x);
        if (op == 1)
            tree.insert(x);
        if (op == 2)
            tree.del(x);
        if (op == 3)
            printf("%d\n", tree.find(x));
        if (op == 4)
            printf("%d\n", tree.kth_number(x));
        if (op == 5)
            tree.insert(x), printf("%d\n", tree.get_prev()), tree.del(x);
        if (op == 6)
            tree.insert(x), printf("%d\n", tree.get_next()), tree.del(x);
    }
    return 0;
}
  


文艺平衡树代码实现

题目链接

Show Code
#include "iostream"
#include "cstdio"
#include "algorithm"

#define rint register int
#define endl '\n'

const int N = 1e5 + 5;

int n, m;
int root, idx;

struct Splay
{
    struct Node
    {
        int s[2];
        int fa;
        int val;
        int size, tag;

        void init(int _val, int _fa)
        {
            val = _val, fa = _fa;
            size = 1;
        }
    } t[N];

    void push_up(int p)
    {
        t[p].size = t[t[p].s[0]].size + t[t[p].s[1]].size + 1;
    }

    void push_down(int p)
    {
        if (t[p].tag)
        {
            std::swap(t[p].s[0], t[p].s[1]);
            t[t[p].s[0]].tag ^= 1;
            t[t[p].s[1]].tag ^= 1;
            t[p].tag = 0;
        }
    }

    void rotate(int x)
    {
        int y = t[x].fa, z = t[y].fa;
        int k = t[y].s[1] == x;
        t[z].s[t[z].s[1] == y] = x, t[x].fa = z;
        t[y].s[k] = t[x].s[k ^ 1], t[t[x].s[k ^ 1]].fa = y;
        t[x].s[k ^ 1] = y, t[y].fa = x;
        push_up(y), push_up(x);
    }

    void splay(int x, int k)
    {
        while (t[x].fa != k)
        {
            int y = t[x].fa;
            int z = t[y].fa;
            if (z != k)
            {
                if ((t[y].s[1] == x) ^ (t[z].s[1] == y))
                {
                    rotate(x);
                }
                else
                {
                    rotate(y);
                }
            }
            rotate(x);
        }
        if (k == 0)
        {
            root = x;
        }
    }

    void insert(int k)
    {
        int u = root, p = 0;
        while (u)
        {
            p = u;
            u = t[u].s[k > t[u].val];
        }
        u = ++idx;
        if (p)
        {
            t[p].s[k > t[p].val] = u;
        }
        t[u].init(k, p);
        splay(u, 0);
    }

    int get_k(int k)
    {
        int u = root;
        while (1)
        {
            push_down(u);
            if (t[t[u].s[0]].size >= k)
            {
                u = t[u].s[0];
            }
            else if (t[t[u].s[0]].size + 1 == k)
            {
                return u;
            }
            else
            {
                k -= t[t[u].s[0]].size + 1;
                u = t[u].s[1];
            }
        }
        return -1;
    }

    void dfs_print(int u)
    {
        push_down(u);
        if (t[u].s[0])
        {
            dfs_print(t[u].s[0]);
        }
        if (t[u].val >= 1 && t[u].val <= n)
        {
            printf("%d ", t[u].val);
        }
        if (t[u].s[1])
        {
            dfs_print(t[u].s[1]);
        }
    }
} tree;

int main()
{
    scanf("%d%d", &n, &m);
    for (rint i = 0; i <= n + 1; i++)
    {
        tree.insert(i);
    }
    while (m--)
    {
        int l, r;
        scanf("%d%d", &l, &r);
        l = tree.get_k(l), r = tree.get_k(r + 2);
        tree.splay(l, 0), tree.splay(r, l);
        tree.t[tree.t[r].s[0]].tag ^= 1;
    }
    tree.dfs_print(root);
    return 0;
}
  


0x05 笛卡尔树

定义

定义和 Treap 相同:(Treap 是权值随机的笛卡尔树)

\(k\) 满足二叉搜索树性质

\(w\) 满足小根堆性质

洛谷模板题下标为 \(k\),元素为 \(w\)

正常构建的 Treap 应该是下标为 \(w\),元素为 \(k\)

构建

我们知道由于有堆性质,每个结点到根的链上深度从小到大第二个权值是单调递增的。

又因为 BST 性质,所以每个结点在第一个权值比其小的结点都被插入后肯定没有右子树,而且肯定不在任意一个已有结点的左子树上。

于是我们想到按第一个权值,在题目中即结点编号,从小到大每次插入一个结点,肯定会插入到从根开始一直访问右儿子直到没有右儿子为止的一条链上某个位置,满足这个位置原来的结点第二个权值大于新插入的结点的第二个权值,且原来结点的父亲结点的第二个权值小于新插入的结点的第二个权值。

因此,我们可以用单调栈来维护一个权值单调递增的下标序列,插入一个点时,如果比栈顶元素大,则将栈顶元素作为插入点的左儿子,否则将插入点作为栈顶元素的左儿子。过程结束后,栈顶元素就是这棵笛卡尔树的根。

int n;
struct Descartes_Tree
{
    int top;
    int stk[N];

    struct Node
    {
        int l, r;
    }t[N];
    
    void build(int n)
    {
        for (rint i = 1; i <= n; i++)
        { 
            while (top != 0 && a[stk[top]] > a[i])
            {
                top--;            
            }
            if (top)
            {
                t[i].l = t[stk[top]].r;
                t[stk[top]].r = i;            
            }
            else
            {
                t[i].l = stk[top + 1];            
            }
            stk[++top] = i;
        }
    }  
} tree;

代码实现

题目链接

Show Code
#include "iostream"
#include "cstdio"
#include "algorithm"

#define rint register int
#define endl '\n'
#define int long long

const int N = 1e7 + 7;

int n, a[N];
    
struct Descartes_Tree
{
    int top;
    int stk[N];

    struct Node
    {
        int l, r;
    }t[N];
    
    void build(int n)
    {
        for (rint i = 1; i <= n; i++)
        { 
            while (top != 0 && a[stk[top]] > a[i])
            {
                top--;            
            }
            if (top)
            {
                t[i].l = t[stk[top]].r;
                t[stk[top]].r = i;            
            }
            else
            {
                t[i].l = stk[top + 1];            
            }
            stk[++top] = i;
        }
    }  
    
    int query_l(int n)
    {
        int ans = 0;
        for (rint i = 1; i <= n; i++)
        {
            ans ^= i * (t[i].l + 1);
        }   
        return ans;
    }
  
    int query_r(int n)
    {
        int ans = 0;
        for (rint i = 1; i <= n; i++)
        {
            ans ^= i * (t[i].r + 1);
        }   
        return ans;
    }  
} tree;
    
signed main()
{
    scanf("%lld", &n);
    
    for (rint i = 1; i <= n; i++)
    {
        scanf("%lld", &a[i]);
    }

    tree.build(n);
    
    printf("%lld %lld", tree.query_l(n), tree.query_r(n));
    
    return 0;
}
  


应用:[TJOI2011] 树的序

题目链接

给一个生成序列,建出一棵笛卡尔树,求字典序最小的可以得到相同笛卡尔树的生成序列

按题意建好树之后输出先序遍历即可

Show Code
#include "iostream"
#include "cstdio"
#include "algorithm"

#define rint register int
#define endl '\n'
#define int long long

const int N = 8e5 + 7;

int n, a[N];
    
struct Descartes_Tree
{
    int top;
    int stk[N];

    struct Node
    {
        int l, r;
    }t[N];
    
    void build(int n)
    {
        for (rint i = 1; i <= n; i++)
        { 
            while (top != 0 && a[stk[top]] > a[i])
            {
                top--;            
            }
            if (top)
            {
                t[i].l = t[stk[top]].r;
                t[stk[top]].r = i;            
            }
            else
            {
                t[i].l = stk[top + 1];            
            }
            stk[++top] = i;
        }
    }  
    
    void inorder(int p)
    {
        if (p == 0)
        {
            return ;
        }
        if (p)
        {
            printf("%lld ", p);
        }
        inorder(t[p].l);
        inorder(t[p].r);
    }
} tree;
    
signed main()
{
    scanf("%lld", &n);
    
    for (rint i = 1; i <= n; i++)
    {
        int x;
        scanf("%lld", &x);
        a[x] = i;
    }

    tree.build(n);

    tree.inorder(tree.stk[1]);

    return 0;
}
  


0x06 线段树套平衡树

我们有时需要维护多维度信息。在这种时候,我们经常需要树套树来记录信息。

关于树套树的构建,我们对于外层线段树正常建树,对于线段树上的某一个节点,建立一棵平衡树,包含该节点所覆盖的序列。具体操作时我们可以将序列元素一个个插入,每经过一个线段树节点,就将该元素加入到该节点的平衡树中。

我们以例题 二逼平衡树 为例;

为了方便实现,我们平衡树部分就直接用 pb_ds 打包好的函数了。

(由于 pb_ds 的时空复杂度肯定比手写慢一点,所以不开 O2 会 TLE,但是一般情况下在 CCF 的数据下都是可以过的)

代码实现

Show Code
#include "bits/extc++.h"
#include "bits/stdc++.h"

#define endl '\n'

using namespace std;
using namespace __gnu_pbds;

const int N = 5e4 + 5;
const int inf = 2147483647;
int n, m, a[N];

tree, null_type, less>, rb_tree_tag, tree_order_statistics_node_update> t[N << 2];
int p[N];

void build(int u, int l, int r)
{
    for (int j = l; j <= r; j++)
    {
        t[u].insert(make_pair(a[j], j));
    }
    if (l == r)
    {
        p[l] = u;
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
}

void change(int x, int v)
{
    for (int i = p[x]; i; i >>= 1)
    {
        t[i].erase(t[i].find(make_pair(a[x], x)));
        t[i].insert(make_pair(v, x));
    }
    a[x] = v;
}

int get_rank(int u, int l, int r, int x, int y, int z)
{
    if (x <= l && y >= r)
    {
        return t[u].order_of_key(make_pair(z, 0));
    }
    int ans = 0;
    int mid = (l + r) >> 1;
    if (x <= mid)
    {
        ans = get_rank(u << 1, l, mid, x, y, z);
    }
    if (y > mid)
    {
        ans += get_rank(u << 1 | 1, mid + 1, r, x, y, z);
    }
    return ans;
}

int get_prev(int u, int l, int r, int x, int y, int z)
{
    if (x <= l && y >= r)
    {
        int v = t[u].order_of_key(make_pair(z, 0));
        if (!v)
        {
            return -inf;
        }
        return t[u].find_by_order(v - 1)->first;
    }

    int mid = (l + r) >> 1;

    if (y <= mid)
    {
        return get_prev(u << 1, l, mid, x, y, z);
    }
    if (x > mid)
    {
        return get_prev(u << 1 | 1, mid + 1, r, x, y, z);
    }

    return max(get_prev(u << 1, l, mid, x, y, z), get_prev(u << 1 | 1, mid + 1, r, x, y, z));
}

int get_next(int u, int l, int r, int x, int y, int z)
{
    if (x <= l && y >= r)
    {
        int v = t[u].order_of_key(make_pair(z, inf));
        if (v == r - l + 1)
        {
            return inf;
        }
        return t[u].find_by_order(v)->first;
    }

    int mid = (l + r) >> 1;
    if (y <= mid)
    {
        return get_next(u << 1, l, mid, x, y, z);
    }
    if (x > mid)
    {
        return get_next(u << 1 | 1, mid + 1, r, x, y, z);
    }
    return min(get_next(u << 1, l, mid, x, y, z), get_next(u << 1 | 1, mid + 1, r, x, y, z));
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> m;

    for (int i = 1; i <= n; i++)
    {
        cin >> a[i];
    }

    build(1, 1, n);

    while (m--)
    {
        int op;
        cin >> op;

        if (op == 3)
        {
            int x, v;
            cin >> x >> v;
            change(x, v);
            continue;
        }

        int l, r, k;
        cin >> l >> r >> k;

        if (op == 1)
        {
            cout << get_rank(1, 1, n, l, r, k) + 1;
        }
        if (op == 2)
        {
            int L = 0, R = 1e8;
            int ans = 0;
            while (L <= R)
            {
                int mid = (L + R) >> 1;
                if (get_rank(1, 1, n, l, r, mid) < k)
                {
                    ans = mid;
                    L = mid + 1;
                }
                else
                {
                    R = mid - 1;
                }
            }
            cout << ans;
        }
        if (op == 4)
        {
            cout << get_prev(1, 1, n, l, r, k);
        }
        if (op == 5)
        {
            cout << get_next(1, 1, n, l, r, k);
        }
        cout << endl;
    }
    return 0;
}
  


0x07 后缀平衡树

后缀之间的大小由字典序定义,后缀平衡树就是一个维护这些后缀顺序的平衡树,即字符串 \(T\) 的后缀平衡树是 \(T\) 所有后缀的有序集合。后缀平衡树上的一个节点相当于原字符串的一个后缀。

特别地,后缀平衡树的中序遍历即为后缀数组。

所以后缀平衡树可以解决许多字符串问题。我们使用替罪羊树实现。

原理

后缀平衡树是一个以字典序为关键字的有序字符串集 \(X\),它可以资瓷两种复杂度为 \(O(\log n)\) 的操作:

  • 向字符串集中加入一个长度为 \(1\) 的字符串 \(S\)

  • 向字符串集中加入字符串 \(xS\),其中 \(S \in X\)

所以,如果对一个字符串逆序执行插入操作,它就是严格的“后缀平衡树”,它的中序遍历也就是后缀数组。但其实上,观察这两个操作,它完全可以维护一个有根树森林。

实现

两种插入操作完全可以视为一种。我们考虑要加入一个字符串 \(xS\),我们要做的就是从根开始与当前节点进行比较,然后进入左/右儿子,直到找到一个空位为止。现在的问题就是,如何比较要插入的字符串和节点上的字符串。

显然我们不能 \(O(n)\) 暴力比较。考虑到要插入的字符串 \(xS\),除去第一个字母之后的 \(S\) 已经在树中,并且树是有序的。那么当第一个字符相同时,我们完全可以利用这棵树进行 \(O(\log n)\) 比较后面的部分,这样的话,插入操作时 \(O(\log ^ 2 n)\) 的。

考虑如何进行 \(O(1)\) 比较,来做到单次操作是 \(O(\log n)\)。我们对于每个节点维护一个 \(key\) 值代表在树中的相对位置。具体方法是:选取一个很大的值域,每次进入下一层时根据左右儿子将值域折半,最终节点的值就是当前值域的 \(mid\) 值。当需要比较两个已经在树中的字符串时,直接比较 \(key\) 值即可,复杂度 \(O(1)\)

对于如何维护平衡,最简单的方法就是使用替罪羊树。

洛谷模板题代码实现

题目链接

由于后缀平衡树维护串头操作比较容易,所以对于这种串尾操作,我们可以直接维护反串的后缀平衡树,添加和删除操作正常维护,注意删除时彻底删除节点而不是使用懒标记,这样更不容易犯错,方便后面的查询。

查询时,假设当前串为 \(S\),查询以 \(T\)\(S\) 中的出现次数,就可以转化为查询有多少个 \(S\) 的后缀是以 \(T\) 为前缀的。我们把 \(T\) 翻转,在后面添加一个字典序极大的字符 'Z'+1,查询它的排名为 \(r\)。然后我们保留后添加的字符,把上一个字符\(-1\),再次查询它的排名为 \(l\)。这样,得到的 \(r-l\) 就是 \(T\)\(S\) 中的出现次数。

Show Code
#include "iostream"
#include "cstdio"
#include "algorithm"
#include "cstring"

#define rint register int
#define endl '\n'

const int N = 8e5 + 5;
const double INF = 1e18;
const double alpha = 0.75;

int n, len;
int mask = 0;
int Q;

char a[N], h[N];
int root;

struct SuffixBST
{
    struct Node
    {
        int l, r;
        int size;
        double tag;
    } t[N];

    int idx, q[N];

    void decode(char s[], int len, int mask)
    {
        for (rint i = 0; i < len; i++)
        {
            mask = (mask * 131 + i) % len;
            std::swap(s[i], s[mask]);
        }
    }

    bool cmp(int x, int y)
    {
        if (h[x] != h[y])
        {
            return h[x] < h[y];
        }
        return t[x - 1].tag < t[y - 1].tag;
    }

    bool cmp1(char s[], int len, int p)
    {
        for (rint i = 1; i <= len; i++, p--)
        {
            if (s[i] < h[p])
            {
                return 1;
            }
            if (s[i] > h[p])
            {
                return 0;
            }
        }
        return 0;
    }

    void new_node(int &idx, int p, double l, double r)
    {
        idx = p;
        t[idx].size = 1;
        t[idx].tag = (l + r) / 2;
        t[idx].l = t[idx].r = 0;
    }

    void push_up(int p)
    {
        if (p == 0)
        {
            return;
        }
        t[p].size = t[t[p].l].size + t[t[p].r].size + 1;
    }

    bool check(int p)
    {
        return alpha * t[p].size > std::max(t[t[p].l].size, t[t[p].r].size);
    }

    void unfold(int p)
    {
        if (p == 0)
        {
            return;
        }
        unfold(t[p].l);
        q[++idx] = p;
        unfold(t[p].r);
    }

    void build(int &p, int x, int y, double l, double r)
    {
        if (x > y)
        {
            p = 0;
            return;
        }
        int mid = (x + y) >> 1;
        double mv = (l + r) / 2;
        p = q[mid];
        t[p].tag = mv;
        build(t[p].l, x, mid - 1, l, mv);
        build(t[p].r, mid + 1, y, mv, r);
        push_up(p);
    }

    void rebuild(int &p, double l, double r)
    {
        idx = 0;
        unfold(p);
        build(p, 1, idx, l, r);
    }

    void insert(int &p, int val, double l, double r)
    {
        if (p == 0)
        {
            new_node(p, val, l, r);
            return;
        }
        if (cmp(val, p))
        {
            insert(t[p].l, val, l, t[p].tag);
        }
        else
        {
            insert(t[p].r, val, t[p].tag, r);
        }
        push_up(p);
        if (!check(p))
        {
            rebuild(p, l, r);
        }
    }

    void del(int &p, int val, double l, double r)
    {
        if (p == 0)
        {
            return;
        }
        if (p == val)
        {
            if (!t[p].l || !t[p].r)
            {
                p = (t[p].l | t[p].r);
            }
            else
            {
                int k = t[p].l;
                int fa = p;
                while (t[k].r)
                {
                    fa = k;
                    t[fa].size--;
                    k = t[k].r;
                }
                if (fa == p)
                {
                    t[k].r = t[p].r;
                }
                else
                {
                    t[k].l = t[p].l;
                    t[k].r = t[p].r;
                    t[fa].r = 0;
                }
                p = k;
                t[p].tag = (l + r) / 2;
            }
        }
        else
        {
            double mv = (l + r) / 2;
            if (cmp(val, p))
            {
                del(t[p].l, val, l, mv);
            }
            else
            {
                del(t[p].r, val, mv, r);
            }
        }
        push_up(p);
        if (!check(p))
        {
            rebuild(p, l, r);
        }
    }

    int query(int p, char s[], int len)
    {
        if (p == 0)
        {
            return 0;
        }
        if (cmp1(s, len, p))
        {
            return query(t[p].l, s, len);
        }
        else
        {
            return t[t[p].l].size + query(t[p].r, s, len) + 1;
        }
    }

    void Insert(char a[])
    {
        decode(a + 1, len, mask);
        for (rint i = 1; i <= len; i++)
        {
            h[++n] = a[i];
            insert(root, n, 0, INF);
        }
    }

    void Del(int x)
    {
        while (x)
        {
            del(root, n, 0, INF);
            n--;
            x--;
        }
    }

    int Count(char a[])
    {
        decode(a + 1, len, mask);
        std::reverse(a + 1, a + len + 1);
        a[len + 1] = 'Z' + 1;
        a[len + 2] = 0;
        int ans = query(root, a, len + 1);
        a[len]--;
        ans -= query(root, a, len + 1);
        return ans;
    }
} tree;

int main()
{
    scanf("%d", &Q);
    scanf("%s", a + 1);
    len = strlen(a + 1);

    for (rint i = 1; i <= len; i++)
    {
        h[++n] = a[i];
        tree.insert(root, n, 0, INF);
    }

    for (rint i = 1; i <= Q; i++)
    {
        char op[10];
        scanf("%s", op);

        if (op[0] == 'A')
        {
            scanf("%s", a + 1);
            len = strlen(a + 1);
            tree.Insert(a);
        }

        if (op[0] == 'D')
        {
            int x;
            scanf("%d", &x);
            tree.Del(x);
        }

        if (op[0] == 'Q')
        {
            scanf("%s", a + 1);
            len = strlen(a + 1);
            int ans = tree.Count(a);
            printf("%d\n", ans);
            mask ^= ans;
        }
    }

    return 0;
}
  


posted @ 2022-11-22 17:49  PassName  阅读(44)  评论(0编辑  收藏  举报