洛谷 P3369 【模板】普通平衡树
有旋Treap模板
//pointer version
#include <bits/stdc++.h>
using namespace std;
// A treap node. Equal keys share one node: rep_cnt counts the copies, and
// siz is the number of stored elements (copies included) in this subtree.
struct Node {
    Node *ch[2] = {nullptr, nullptr}; // ch[0] = left child, ch[1] = right child
    int val;                          // key
    int rank;                         // random heap priority (set from rand())
    int rep_cnt = 1;                  // multiplicity of val
    int siz = 1;                      // subtree size including multiplicities

    // explicit: an int should never silently convert into a Node.
    explicit Node(int val) : val(val), rank(rand()) {}

    // Recompute siz from the children (which must already be up to date).
    void upd_siz() {
        siz = rep_cnt;
        if (ch[0] != nullptr) siz += ch[0]->siz;
        if (ch[1] != nullptr) siz += ch[1]->siz;
    }
};
// Rotation direction. The enumerator's value is the index of the child that
// is lifted into the parent's position: LF (1) lifts the right child (a left
// rotation), RT (0) lifts the left child (a right rotation).
enum rot_type {LF = 1, RT = 0};

// Rotate the subtree rooted at `cur` so that cur->ch[dir] becomes its root;
// `cur` is rewritten to point at the new root. cur->ch[dir] must be non-null.
// Sizes are refreshed bottom-up: the demoted node first, because the promoted
// node's size is computed from it.
void _rotate(Node *&cur, rot_type dir) {
    Node *up = cur->ch[dir];
    cur->ch[dir] = up->ch[!dir];
    up->ch[!dir] = cur;
    cur->upd_siz(); // demoted node: its children are final now
    up->upd_siz();  // promoted node: reads the demoted node's fresh size
    cur = up;
}
// Insert one occurrence of `val` into the treap rooted at `cur`.
// Duplicates increase rep_cnt instead of creating a new node.
// Heap order is a min-heap on `rank` (the same order _del maintains by always
// lifting the child with the smaller rank), so both sides must use `<`.
void _insert(Node *&cur, int val) {
    if (cur == nullptr) {
        cur = new Node(val);
        return;
    }
    if (cur->val == val) {
        cur->rep_cnt++;
        cur->siz++;
        return;
    }
    // Smaller keys go left (0), larger keys go right (1).
    const int side = cur->val < val ? 1 : 0;
    _insert(cur->ch[side], val);
    // Lift the child if it now violates the min-heap order. Using different
    // comparisons for the two sides would break the invariant and let the
    // tree degenerate.
    if (cur->ch[side]->rank < cur->rank) {
        _rotate(cur, side == 1 ? LF : RT);
    }
    cur->upd_siz();
}
// Remove one occurrence of `val` from the treap rooted at `cur`.
// Does nothing if `val` is not present (previously this dereferenced nullptr).
// A node with two children is rotated down, lifting the child with the smaller
// rank to keep the min-heap order, until it has at most one child. It is then
// replaced by that child (or by nullptr) and freed.
void _del(Node *&cur, int val) {
    if (cur == nullptr) return; // value not in the tree
    if (cur->val != val) {
        _del(cur->ch[cur->val < val ? 1 : 0], val);
        cur->upd_siz();
        return;
    }
    if (cur->rep_cnt > 1) {
        cur->rep_cnt--;
        cur->siz--;
        return;
    }
    if (cur->ch[0] != nullptr && cur->ch[1] != nullptr) {
        rot_type dir = cur->ch[0]->rank < cur->ch[1]->rank ? RT : LF;
        _rotate(cur, dir);
        // After the rotation the node to delete is on the opposite side.
        _del(cur->ch[!dir], val);
        cur->upd_siz();
        return;
    }
    // At most one child: splice it into this position. Its subtree size is
    // already correct; ancestors refresh theirs on the way back up.
    Node *victim = cur;
    cur = victim->ch[0] != nullptr ? victim->ch[0] : victim->ch[1];
    delete victim;
}
// Rank of `val`: one more than the number of stored elements (copies counted)
// that are strictly smaller than `val`. `val` need not be present.
int _query_rank(Node *&cur, int val) {
    int smaller = 0;
    Node *node = cur;
    while (node != nullptr) {
        const int left = node->ch[0] != nullptr ? node->ch[0]->siz : 0;
        if (node->val == val) return smaller + left + 1;
        if (node->val < val) {
            // The whole left subtree and every copy of node->val are smaller.
            if (node->ch[1] == nullptr) return smaller + node->siz + 1;
            smaller += left + node->rep_cnt;
            node = node->ch[1];
        } else {
            if (node->ch[0] == nullptr) return smaller + 1;
            node = node->ch[0];
        }
    }
    return smaller + 1; // only reached for an empty tree
}
// Value of the element with 1-based rank `rank` (copies counted),
// or 0 when the rank is out of range.
int _query_val(Node *&cur, int rank) {
    Node *node = cur;
    while (node != nullptr) {
        const int left = node->ch[0] == nullptr ? 0 : node->ch[0]->siz;
        if (rank <= left) {
            node = node->ch[0];
            continue;
        }
        rank -= left;
        if (rank <= node->rep_cnt) return node->val;
        rank -= node->rep_cnt;
        node = node->ch[1];
    }
    return 0;
}
// Largest stored value strictly less than `val`, or -1 when there is none
// (including an empty tree, which previously crashed).
// NOTE(review): -1 is ambiguous if -1 itself can be stored.
// Iterative search with a local best-so-far, replacing the old global
// scratch variable q_pre_tmp.
int _query_prev(Node *cur, int val) {
    int best = -1;
    while (cur != nullptr) {
        if (cur->val < val) {
            best = cur->val;   // valid candidate; a larger one can only be to the right
            cur = cur->ch[1];
        } else {
            cur = cur->ch[0];
        }
    }
    return best;
}
// Smallest stored value strictly greater than `val`, or -1 when there is none
// (including an empty tree, which previously crashed).
// NOTE(review): -1 is ambiguous if -1 itself can be stored.
// Iterative search with a local best-so-far, replacing the old global
// scratch variable q_suf_tmp.
int _query_sufv(Node *cur, int val) {
    int best = -1;
    while (cur != nullptr) {
        if (cur->val > val) {
            best = cur->val;   // valid candidate; a smaller one can only be to the left
            cur = cur->ch[0];
        } else {
            cur = cur->ch[1];
        }
    }
    return best;
}
// Reads n, then n lines "op x":
//   1 x: insert x             2 x: delete one copy of x
//   3 x: print rank of x      4 x: print value with rank x
//   5 x: print predecessor    6 x: print successor
int main() {
    // Fast I/O: up to 1e5 lines. '\n' avoids the per-line flush of std::endl.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n = 0;
    if (!(std::cin >> n)) return 0;
    Node *root = nullptr;
    while (n-- > 0) {
        int op, x;
        if (!(std::cin >> op >> x)) break;
        switch (op) {
            case 1: _insert(root, x); break;
            case 2: _del(root, x); break;
            case 3: std::cout << _query_rank(root, x) << '\n'; break;
            case 4: std::cout << _query_val(root, x) << '\n'; break;
            case 5: std::cout << _query_prev(root, x) << '\n'; break;
            case 6: std::cout << _query_sufv(root, x) << '\n'; break;
            default: break;
        }
    }
    return 0;
}
//array version
#include <bits/stdc++.h>
using namespace std;
// Node pool for the array-based treap. All arrays are zero-initialized;
// index 0 serves as the "no node" sentinel and real nodes are allocated
// from index 1 upward via ++tot.
const int N = 100000 + 1;   // pool capacity (node indices 1 .. N-1)
int n;                      // NOTE(review): unused; main() declares its own local n
int tot;                    // number of nodes allocated so far
int val[N], cnt[N], siz[N]; // key, multiplicity, subtree size (copies counted)
int rnd[N];                 // random heap priority
int child[N][2];            // [0] = left child, [1] = right child, 0 = none
// Recompute the subtree size of node o from its own multiplicity and its
// children's sizes. A missing child is index 0, whose siz is expected to be 0.
void pushup(int o) {
    const int left = child[o][0];
    const int right = child[o][1];
    siz[o] = cnt[o] + siz[left] + siz[right];
}
// Rotate the subtree rooted at o; o is rewritten to the new root.
// d == 1 lifts the left child (right rotation), d == 0 lifts the right child
// (left rotation). The demoted node's size is refreshed before the new root's.
void rotate(int &o, int d) {
    const int other = d ^ 1;
    const int pivot = child[o][other];
    const int old_root = o;
    child[old_root][other] = child[pivot][d];
    child[pivot][d] = old_root;
    o = pivot;
    pushup(old_root);
    pushup(pivot);
}
// Insert one copy of v into the treap rooted at o (o == 0 means empty).
// Priorities form a max-heap: a child whose rnd exceeds its parent's is
// rotated up. Duplicates only bump cnt.
void insert(int &o, int v) {
    if (o == 0) {
        o = ++tot;
        val[o] = v;
        rnd[o] = rand();
        cnt[o] = 1;
        siz[o] = 1;
        child[o][0] = 0;
        child[o][1] = 0;
        return;
    }
    if (val[o] != v) {
        const int side = val[o] < v ? 1 : 0; // 0: go left, 1: go right
        insert(child[o][side], v);
        // Lifting the left child is rotate(o, 1); lifting the right is rotate(o, 0).
        if (rnd[o] < rnd[child[o][side]]) rotate(o, side ^ 1);
    } else {
        cnt[o]++;
    }
    pushup(o);
}
// Remove one copy of v from the treap rooted at o; a no-op if v is absent.
// Unlike a plain BST delete, a node with two children is rotated downward,
// lifting the child with the larger rnd to keep the max-heap order, until it
// has at most one child. It is then spliced out.
void del(int &o, int v) {
    // Absent value: stop here. Without this guard the search recursed on the
    // null sentinel (child[0][*] == 0) forever for any v != val[0].
    if (!o) return;
    if (val[o] == v) {
        if (cnt[o] > 1) {
            cnt[o]--;
            pushup(o);
            return;
        }
        if (child[o][0] && child[o][1]) {
            // d == 1 lifts the left child, d == 0 the right one; afterwards the
            // target node is the new root's child on side d.
            const int d = rnd[child[o][0]] > rnd[child[o][1]] ? 1 : 0;
            rotate(o, d);
            del(child[o][d], v);
            pushup(o);
        } else {
            // Zero or one child: replace the node by that child (or by 0).
            // The child's subtree size is already correct.
            o = child[o][0] ? child[o][0] : child[o][1];
        }
        return;
    }
    del(child[o][val[o] < v], v); // 0: left when val[o] > v, 1: right otherwise
    pushup(o);
}
// Rank of v: 1 + number of stored elements (copies counted) strictly smaller
// than v. v need not be present; o == 0 is the empty tree.
// Every path now returns explicitly; previously control could fall off the
// end of this non-void function as far as the compiler could tell.
int queryrank(int o, int v) {
    if (!o) return 1;
    if (val[o] == v) return siz[child[o][0]] + 1;
    if (val[o] > v) return queryrank(child[o][0], v);
    // val[o] < v: the left subtree and every copy of val[o] are smaller.
    return queryrank(child[o][1], v) + siz[child[o][0]] + cnt[o];
}
// Value with 1-based rank k (copies counted), or 0 if k is out of range.
int querykth(int o, int k) {
    while (o) {
        const int left = siz[child[o][0]];
        if (k <= left) {
            o = child[o][0];
        } else if (k > left + cnt[o]) {
            k -= left + cnt[o];
            o = child[o][1];
        } else {
            return val[o];
        }
    }
    return 0;
}
// Largest stored value strictly less than x; 0 if none exists.
int find_pre(int o, int x) {
    int best = 0;
    // Go right (index 1) after a smaller key, left (index 0) otherwise.
    for (; o != 0; o = child[o][val[o] < x]) {
        if (val[o] < x) best = val[o];
    }
    return best;
}
// Smallest stored value strictly greater than x; 0 if none exists.
int find_suf(int o, int x) {
    int best = 0;
    // Go left (index 0) after a larger key, right (index 1) otherwise.
    for (; o != 0; o = child[o][val[o] <= x]) {
        if (val[o] > x) best = val[o];
    }
    return best;
}
// Reads n, then n lines "op x":
//   1 x: insert x             2 x: delete one copy of x
//   3 x: print rank of x      4 x: print value with rank x
//   5 x: print predecessor    6 x: print successor
int main () {
    // Fast I/O: up to 1e5 lines. '\n' avoids the per-line flush of std::endl.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n = 0;
    int root = 0; // 0 = empty tree
    if (!(std::cin >> n)) return 0;
    while (n-- > 0) {
        int op, x;
        if (!(std::cin >> op >> x)) break;
        switch (op) {
            case 1: insert(root, x); break;
            case 2: del(root, x); break;
            case 3: std::cout << queryrank(root, x) << '\n'; break;
            case 4: std::cout << querykth(root, x) << '\n'; break;
            case 5: std::cout << find_pre(root, x) << '\n'; break;
            case 6: std::cout << find_suf(root, x) << '\n'; break;
            default: break;
        }
    }
    return 0;
}
𝓐𝓬𝓱𝓲𝓮𝓿𝓮𝓶𝓮𝓷𝓽 𝓹𝓻𝓸𝓿𝓲𝓭𝓮𝓼 𝓽𝓱𝓮 𝓸𝓷𝓵𝔂 𝓻𝓮𝓪𝓵
𝓹𝓵𝓮𝓪𝓼𝓾𝓻𝓮 𝓲𝓷 𝓵𝓲𝓯𝓮