LibreOJ104 - 普通平衡树 (平衡树)

Description

这是一道模板题。

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入 x 数;
  2. 删除 x 数(若有多个相同的数,因只删除一个);
  3. 查询 x 数的排名(若有多个相同的数,因输出最小的排名);
  4. 查询排名为 x 的数;
  5. 求 x 的前趋(前趋定义为小于 x,且最大的数);
  6. 求 x 的后继(后继定义为大于 x,且最小的数)。

思路

大模板题,写完非常爽。
用无旋treap实现

#include <iostream>
#include <cstdio>
#include <queue>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include <cstring>
#include <string>
#include <stack>
#include <deque>
#include <cmath>
#include <iomanip>
#include <cctype>

#define endl '\n'
#define IOS                       \
    std::ios::sync_with_stdio(0); \
    cout.tie(0);                  \
    cin.tie(0);
#define FILE freopen("..//data_generator//in.txt", "r", stdin), freopen("res.txt", "w", stdout)
#define FI freopen("..//data_generator//in.txt", "r", stdin)
#define FO freopen("res.txt", "w", stdout)
#define pb cpush_back
#define mp make_pair
#define seteps(N) fixed << setprecision(N)

typedef long long ll;
using namespace std;
/*-----------------------------------------------------------------*/

#define INF 0x3f3f3f3f
const int N = 1e5 + 10;
const int M = 1e9 + 7;
const double eps = 1e-8;

int pos[N];
int lc[N], rc[N];
int cnt[N];
int cur[N];
int val[N];
int si;

typedef pair<int, int> PII;

int find(int rt, int v) {
    if (!rt)
        return rt;
    if (v > val[rt]) {
        return find(rc[rt], v);
    } else if (v < val[rt]) {
        return find(lc[rt], v);
    } else
        return rt;
}

int getrk(int rt, int v) {
    if (!rt)
        return 1;
    if (v > val[rt]) {
        return getrk(rc[rt], v) + cnt[lc[rt]] + cur[rt];
    } else if (v < val[rt]) {
        return getrk(lc[rt], v);
    } else
        return cnt[lc[rt]] + 1;
}

int getv(int rt, int rk) {
    if (rk <= cnt[lc[rt]])
        return getv(lc[rt], rk);
    if (rk - cnt[lc[rt]] <= cur[rt])
        return val[rt];
    return getv(rc[rt], rk - cur[rt] - cnt[lc[rt]]);
}

void update(int rt, int v, int x) {
    if (v > val[rt]) {
        update(rc[rt], v, x);
    } else if (v < val[rt]) {
        update(lc[rt], v, x);
    } else {
        cur[rt] += x;
    }
    cnt[rt] += x;
}

PII split(int rt, int key) {
    if (!rt) {
        return mp(0, 0);
    }
    if (key < val[rt]) {
        PII o = split(lc[rt], key);
        lc[rt] = o.second;
        cnt[rt] = cnt[lc[rt]] + cnt[rc[rt]] + cur[rt];
        return mp(o.first, rt);
    } else {
        PII o = split(rc[rt], key);
        rc[rt] = o.first;
        cnt[rt] = cnt[lc[rt]] + cnt[rc[rt]] + cur[rt];
        return mp(rt, o.second);
    }
}

int merge(int lrt, int rrt) {
    if (!lrt)
        return rrt;
    if (!rrt)
        return lrt;
    if (pos[lrt] > pos[rrt]) {
        rc[lrt] = merge(rc[lrt], rrt);
        cnt[lrt] = cnt[lc[lrt]] + cnt[rc[lrt]] + cur[lrt];
        return lrt;
    } else {
        lc[rrt] = merge(lrt, lc[rrt]);
        cnt[rrt] = cnt[lc[rrt]] + cnt[rc[rrt]] + cur[rrt];
        return rrt;
    }
}

void insert(int& rt, int v) {
    if (find(rt, v)) {
        update(rt, v, 1);
    } else {
        si++;
        pos[si] = rand();
        cnt[si] = cur[si] = 1;
        val[si] = v;
        PII o = split(rt, v);
        rt = merge(merge(o.first, si), o.second);
    }
}

void erase(int& rt, int v) {
    int p = find(rt, v);
    if (!p)
        return ;
    update(rt, v, -1);
    if (cur[p] == 0) {
        PII o = split(rt, v);
        rt = merge(split(o.first, v - 1).first, o.second);
    }
}

void print(int rt) {
    if(!rt) return ;
    print(lc[rt]);
    cout << val[rt] << " ";
    print(rc[rt]);
}

int main() {
    IOS;
    //FILE;
    int rt = 0;
    int n;
    cin >> n;
    while (n--) {
        int opt;
        cin >> opt;
        int x;
        cin >> x;
        if (opt == 1) {
            insert(rt, x);
        } else if (opt == 2) {
            erase(rt, x);
        } else if (opt == 3) {
            cout << getrk(rt, x) << endl;
        } else if (opt == 4) {
            cout << getv(rt, x) << endl;
        } else if (opt == 5) {
            int rk = getrk(rt, x);
            cout << getv(rt, rk - 1) << endl;
        } else if (opt == 6) {
            int rk = getrk(rt, x + 1);
            cout << getv(rt, rk) << endl;
        }
    }
    //print(rt);
}
posted @ 2020-06-27 13:12  limil  阅读(152)  评论(0编辑  收藏  举报