洛谷P2596 [ZJOI2006] 书架 题解 splay tree 模板题

题目链接:https://www.luogu.com.cn/problem/P2596

主要涉及的操作就是:

  • 找到某一个编号的点(这个操作可以不用 splay tree 维护)
  • 删除某个点
  • 将某一个点插入到最前面,最后面,或者某一个位置
  • 查询前序遍历为 \(k\) 的节点编号

因为每次删除都会又把这个点加回去,所以可以复用 \(n\) 个点。

好久没写 splay tree 了代码 略()丑

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 8e4 + 5;

struct Node {
    int s[2], p, sz, id;
    Node() {}
    Node(int _p) { s[0] = s[1] = 0; p = _p; sz = 1; }

    void _init(int _p) {
        s[0] = s[1] = 0;
        p = _p;
        sz = 1;
    }

} tr[maxn];
int root, idx;

void push_up(int x) {
    auto &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
    u.sz = l.sz + 1 + r.sz;
}

void f_s(int p, int u, int k) {
    tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1] == y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x)^(tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

int n, m, P[maxn], id[maxn]; // id[s] 表示编号为 s 的书对应的节点编号

int build(int l, int r, int p) {
    if (l > r) return 0;
    int mid = (l + r) / 2, x = mid;
    tr[x] = Node(p);
    tr[x].id = P[mid];
    tr[x].s[0] = build(l, mid-1, x);
    tr[x].s[1] = build(mid+1, r, x);
    push_up(x);
    if (l == 1 && r == n) root = x;
    return x;
}

int get_k(int k) {
    int u = root;
    while (u) {
        if (tr[tr[u].s[0]].sz >= k) u = tr[u].s[0];
        else if (tr[tr[u].s[0]].sz + 1 == k) return u;
        else k -= tr[tr[u].s[0]].sz + 1, u = tr[u].s[1];
    }
    return -1;
}

void del(int x) {
    int y, z;
    splay(x, 0);
    y = tr[x].s[0];
    z = tr[x].s[1];
    if (!y || !z) {
        root = y + z;
        tr[y + z].p = 0;
        return;
    }
    while (tr[y].s[1]) y = tr[y].s[1];
    while (tr[z].s[0]) z = tr[z].s[0];
    splay(y, 0);
    splay(z, y);
    tr[z].s[0] = 0;
    f_s(y, z, 1);
    push_up(z), push_up(y);
}

void ins_tb(int p, int k, int x) {
    if (!tr[p].s[k]) {
        tr[p].s[k] = x;
        tr[x]._init(p);
    }
    else
        ins_tb(tr[p].s[k], k, x);
    push_up(p);
}

void dfs(int u) {
    if (!u)
        return;
    dfs(tr[u].s[0]);
    cout << tr[u].id << " ";
    dfs(tr[u].s[1]);
}

void test() { // 这个函数用来测试
    puts("[test]");
    dfs(root);
    puts("\n");
}

int main() {

    tr[0].s[0] = tr[0].s[1] = tr[0].p = tr[0].sz = tr[0].id = 0;

    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", P+i);
        id[ P[i] ] = i;
    }
    build(1, n, 0);
    idx = n;

    char op[10];
    int s, t;
    while (m--) {
//        test();

        scanf("%s%d", op, &s);
        int x = id[s];
        if (op[0] == 'T') {         // Top
            del(x);
            ins_tb(root, 0, x);
        }
        else if (op[0] == 'B') {    // Bottom
            del(x);
            ins_tb(root, 1, x);
        }
        else if (op[0] == 'I') {    // Insert
            scanf("%d", &t);
            if (!t)
                continue;
            splay(x, 0);
            int rk = tr[ tr[x].s[0] ].sz + t;
            del(x);
            if (rk == 0) {
                ins_tb(root, 0, x);
                continue;
            }
            int y = get_k(rk);
            splay(y, 0);
            assert(tr[y].sz == n - 1);
            if (!tr[y].s[1]) {
                tr[x]._init(y);
                tr[y].s[1] = x;
                push_up(y);
            }
            else {
                int z = tr[y].s[1];
                while (tr[z].s[0]) z = tr[z].s[0];
                splay(z, y);
                assert(!tr[z].s[0]);
                tr[x]._init(z);
                tr[z].s[0] = x;
                f_s(y, z, 1);
                push_up(z), push_up(y);
            }
        }
        else if (op[0] == 'A') {    // Ask
            splay(x, 0);
            int ans = tr[ tr[x].s[0] ].sz;
            printf("%d\n", ans);
        }
        else {                      // Query
            assert(op[0] == 'Q');
            x = get_k(s);
            printf("%d\n", tr[x].id);
        }
    }

    return 0;
}
posted @ 2024-10-22 16:38  quanjun  阅读(2)  评论(0编辑  收藏  举报