平衡树
Treap 实现
// Rotation-based Treap over a multiset of ints. Each node keeps one distinct
// value v together with its multiplicity c; s is the subtree size counting
// multiplicities. Priorities r form a max-heap: a child whose r exceeds its
// parent's is rotated above it.
//
// Input: n, then n pairs "opt x":
//   1 insert x            2 delete one copy of x
//   3 rank of x           4 value whose rank is x
//   5 predecessor of x (-INF if none)   6 successor of x (INF if none)
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>

using namespace std;

// No `register` here: that storage class is ill-formed since C++17.
#define rep(i, a, b) for (int i = a; i <= b; ++i)
#define INF (1 << 30)

// Reads a signed decimal integer from stdin. `c` is an int so that EOF is
// representable; hitting EOF before any digit yields 0 instead of looping forever.
inline int read() {
    int w = 0, f = 1, c = getchar();
    while (c != EOF && !isdigit(c)) f = c == '-' ? -1 : f, c = getchar();
    while (isdigit(c)) w = (w << 3) + (w << 1) + (c ^ '0'), c = getchar();
    return w * f;
}

struct Node {
    Node* ch[2];
    int v, r, c, s;  // value, heap priority, multiplicity of v, subtree size
    // Direction to descend for x: -1 if x == v, 0 for the left child, 1 for the right.
    int cmp(int x) const {
        if (x == v) return -1;
        return x < v ? 0 : 1;
    }
    void maintain() { s = c + ch[0]->s + ch[1]->s; }
};

// Sentinel used for every empty subtree: its children point back to itself and s == c == 0.
Node* null;

// Allocates an empty node into o with a random priority and sentinel children.
void newnode(Node* &o) { o = new Node(); o->r = rand(); o->ch[0] = o->ch[1] = null; o->v = o->s = o->c = 0; }
// Stores value v in o with multiplicity 1.
void setnode(Node* &o, int v) { o->v = v; o->s = o->c = 1; }

// Lifts o->ch[d^1] into o's position (d == 0 lifts the right child, d == 1 the left).
void rotate(Node* &o, int d) {
    Node* k = o->ch[d ^ 1];
    o->ch[d ^ 1] = k->ch[d];
    k->ch[d] = o;
    o->maintain(), k->maintain();
    o = k;
}

// Inserts one copy of v; a duplicate only bumps the existing node's count.
void insert(Node* &o, int v) {
    if (o == null) { newnode(o); setnode(o, v); return; }
    int d = o->cmp(v);
    if (d == -1) o->c++, o->s++;
    else {
        insert(o->ch[d], v);
        if (o->ch[d]->r > o->r) rotate(o, d ^ 1);  // restore max-heap order on r
        o->maintain();
    }
}

// Removes one copy of v; does nothing if v is absent.
void remove(Node* &o, int v) {
    // Without this guard an absent v recursed forever through the
    // self-referencing sentinel, or (for v == null->v == 0) deleted the sentinel.
    if (o == null) return;
    int d = o->cmp(v);
    if (d == -1) {
        if (o->c) o->c--;
        if (!o->c) {
            if (o->ch[0] != null && o->ch[1] != null) {
                // Lift the higher-priority child, then keep pushing v down.
                int d2 = o->ch[0]->r < o->ch[1]->r ? 0 : 1;
                rotate(o, d2);
                remove(o->ch[d2], v);
            } else {
                Node* u = o;
                o = o->ch[0] != null ? o->ch[0] : o->ch[1];
                delete u;
            }
        }
    } else remove(o->ch[d], v);
    o->maintain();
}

// Returns 1 + the number of stored values strictly less than v (v need not be present).
int get_rank(Node* o, int v) {
    if (o == null) return 1;
    int d = o->cmp(v);
    if (d == -1) return o->ch[0]->s + 1;
    return get_rank(o->ch[d], v) + d * (o->ch[0]->s + o->c);
}

// Returns the value with 1-based rank rk, or -INF if rk is out of range.
int get_val(Node* o, int rk) {
    if (o == null) return -INF;
    if (rk <= o->ch[0]->s) return get_val(o->ch[0], rk);
    if (rk > o->ch[0]->s + o->c) return get_val(o->ch[1], rk - o->ch[0]->s - o->c);
    return o->v;
}

// Largest stored value strictly less than v, or -INF if there is none.
int get_pre(Node* o, int v) {
    if (o == null) return -INF;
    if (o->v >= v) return get_pre(o->ch[0], v);
    return max(o->v, get_pre(o->ch[1], v));
}

// Smallest stored value strictly greater than v, or INF if there is none.
int get_next(Node* o, int v) {
    if (o == null) return INF;
    if (o->v <= v) return get_next(o->ch[1], v);
    return min(o->v, get_next(o->ch[0], v));
}

Node* root;

int n;

int main() {
    newnode(null);  // the sentinel's children end up pointing at itself
    root = null;
    n = read();
    rep(i, 1, n) {
        int opt = read(), v = read();
        if (opt == 1) insert(root, v);
        if (opt == 2) remove(root, v);
        if (opt == 3) printf("%d\n", get_rank(root, v));
        if (opt == 4) printf("%d\n", get_val(root, v));
        if (opt == 5) printf("%d\n", get_pre(root, v));
        if (opt == 6) printf("%d\n", get_next(root, v));
    }
    return 0;
}
(一种较短的写法)
// Compact rotation-based Treap over a multiset of ints; equal values are stored
// as separate nodes. Priorities r form a max-heap.
//
// Input: n, then n pairs "opt x":
//   1 insert x            2 delete one copy of x
//   3 rank of x           4 value whose rank is x
//   5 predecessor of x (-INF if none)   6 successor of x (INF if none)
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>

using namespace std;

// No `register` here: that storage class is ill-formed since C++17.
#define rep(i, a, b) for (int i = a; i <= b; ++i)
#define maxx(a, b) a = max(a, b)
#define minn(a, b) a = min(a, b)
#define INF (1 << 30)

// Reads a signed decimal integer from stdin. `c` is an int so that EOF is
// representable; hitting EOF before any digit yields 0 instead of looping forever.
inline int read() {
    int w = 0, f = 1, c = getchar();
    while (c != EOF && !isdigit(c)) f = c == '-' ? -1 : f, c = getchar();
    while (isdigit(c)) w = (w << 3) + (w << 1) + (c ^ '0'), c = getchar();
    return w * f;
}

struct Node {
    Node *ch[2];
    int v, r, s;  // value, heap priority, subtree size
    Node(int v, Node *son) : v(v) { r = rand(); s = 1; ch[0] = ch[1] = son; }
    void maintain() { s = ch[0]->s + ch[1]->s + 1; }
} *root, *null;

// Seeds rand() and builds the sentinel used for every empty subtree
// (children point back to itself, s == 0).
void init() { srand(19260817); null = new Node(0, NULL); null->ch[0] = null->ch[1] = null; null->s = 0; }

// Lifts o->ch[d^1] into o's position (d == 0 lifts the right child, d == 1 the left).
void rotate(Node* &o, int d) { Node *k = o->ch[d^1]; o->ch[d^1] = k->ch[d]; k->ch[d] = o; o->maintain(); k->maintain(); o = k; }

// Inserts one copy of v (equal values go to the right subtree).
void insert(Node* &o, int v) {
    if (o == null) { o = new Node(v, null); return; }
    int d = v < o->v ? 0 : 1;
    insert(o->ch[d], v);
    if (o->ch[d]->r > o->r) rotate(o, d ^ 1);  // restore max-heap order on r
    o->maintain();
}

// Removes one copy of v; does nothing if v is absent.
void remove(Node* &o, int v) {
    // Without this guard an absent v recursed forever through the
    // self-referencing sentinel, or (for v == null->v == 0) deleted the sentinel.
    if (o == null) return;
    if (o->v != v) remove(o->ch[v < o->v ? 0 : 1], v);
    else if (o->ch[0] != null && o->ch[1] != null) {
        // Lift the higher-priority child, then keep pushing v down.
        int d = o->ch[0]->r > o->ch[1]->r ? 1 : 0;
        rotate(o, d);
        remove(o->ch[d], v);
    } else {
        Node *u = o;
        o = o->ch[o->ch[0] == null ? 1 : 0];
        delete u;
    }
    if (o != null) o->maintain();
}

// Returns 1 + the number of stored values strictly less than v (v need not be
// present; the old version returned INF in that case).
int get_rank(Node *o, int v) {
    int less = 0;
    while (o != null) {
        if (v <= o->v) o = o->ch[0];                 // nothing on the right is < v
        else less += o->ch[0]->s + 1, o = o->ch[1];  // o and its left subtree are < v
    }
    return less + 1;
}

// Returns the value with 1-based rank k, or -INF if k is out of range.
// (The old version fell off the end of the function here, which is UB.)
int get_val(Node *o, int k) {
    while (o != null) {
        int here = o->ch[0]->s + 1;  // rank of o within its own subtree
        if (k == here) return o->v;
        if (k < here) o = o->ch[0];
        else k -= here, o = o->ch[1];
    }
    return -INF;
}

// Largest stored value strictly less than v, or -INF if there is none.
int get_pre(Node *o, int v) {
    int res = -INF;
    while (o != null) o = o->ch[o->v >= v ? 0 : (maxx(res, o->v), 1)];
    return res;
}

// Smallest stored value strictly greater than v, or INF if there is none.
int get_next(Node *o, int v) {
    int res = INF;
    while (o != null) o = o->ch[o->v <= v ? 1 : (minn(res, o->v), 0)];
    return res;
}

int n, opt, v;

int main() {
    init(); root = null;
    n = read();
    rep(i, 1, n) {
        opt = read(), v = read();
        if (opt == 1) insert(root, v);
        if (opt == 2) remove(root, v);
        if (opt == 3) printf("%d\n", get_rank(root, v));
        if (opt == 4) printf("%d\n", get_val(root, v));
        if (opt == 5) printf("%d\n", get_pre(root, v));
        if (opt == 6) printf("%d\n", get_next(root, v));
    }
    return 0;
}
fhq-treap 实现
// fhq-treap (split/merge Treap, no rotations) over a multiset of ints; equal
// values are stored as separate nodes. Priorities r form a max-heap.
//
// Input: n, then n pairs "opt x":
//   1 insert x            2 delete one copy of x
//   3 rank of x           4 value whose rank is x
//   5 predecessor of x (-INF if none)   6 successor of x (INF if none)
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>

using namespace std;

// No `register` here: that storage class is ill-formed since C++17.
#define rep(i, a, b) for (int i = a; i <= b; ++i)
#define INF (1 << 30)

// Reads a signed decimal integer from stdin. `c` is an int so that EOF is
// representable; hitting EOF before any digit yields 0 instead of looping forever.
inline int read() {
    int w = 0, f = 1, c = getchar();
    while (c != EOF && !isdigit(c)) f = c == '-' ? -1 : f, c = getchar();
    while (isdigit(c)) w = (w << 3) + (w << 1) + (c ^ '0'), c = getchar();
    return w * f;
}

struct Node {
    Node* ch[2];
    int v, s, r;  // value, subtree size, heap priority
    void maintain() { s = 1 + ch[0]->s + ch[1]->s; }
    Node(int v, Node* son) : v(v) { ch[0] = ch[1] = son; s = 1; r = rand(); }
} *null, *root;

// Builds the sentinel used for every empty subtree (children point back to itself, s == 0).
void initnull() { null = new Node(0, NULL); null->ch[0] = null->ch[1] = null; null->s = 0; }

// Joins a and b into one treap; every value in a must precede every value in b.
Node* merge(Node* a, Node* b) {
    if (a == null) return b;
    if (b == null) return a;
    if (a->r > b->r) { a->ch[1] = merge(a->ch[1], b); a->maintain(); return a; }
    b->ch[0] = merge(a, b->ch[0]); b->maintain(); return b;
}

// Splits o into l (its first k nodes in order) and r (the rest).
void split_rank(Node *o, int k, Node* &l, Node* &r) {
    if (o == null) { l = r = null; return; }
    if (o->ch[0]->s + 1 <= k) { l = o; split_rank(l->ch[1], k - o->ch[0]->s - 1, l->ch[1], r); }
    else { r = o; split_rank(r->ch[0], k, l, r->ch[0]); }
    o->maintain();
}

// Splits o into l (values <= v) and r (values > v).
void split_val(Node *o, int v, Node* &l, Node* &r) {
    if (o == null) { l = r = null; return; }
    if (o->v <= v) { l = o; split_val(l->ch[1], v, l->ch[1], r); }
    else { r = o; split_val(r->ch[0], v, l, r->ch[0]); }
    o->maintain();
}

// Inserts one copy of v.
void insert(Node* &root, int v) {
    Node *l, *r, *mid = new Node(v, null);
    split_val(root, v, l, r);
    root = merge(l, merge(mid, r));
}

// Removes one copy of v; does nothing if v is absent.
void remove(Node* &root, int v) {
    Node *l, *mid, *r;
    split_val(root, v - 1, l, r);  // l: values < v (assumes v > INT_MIN)
    split_val(r, v, mid, r);       // mid: values == v, r: values > v
    // v absent: mid is the sentinel and must not be deleted.
    if (mid == null) { root = merge(l, r); return; }
    // Drop mid's root node (one copy of v) and keep any other copies.
    root = merge(l, merge(merge(mid->ch[0], mid->ch[1]), r));
    delete mid;
}

// Returns the value with 1-based rank rk, or -INF if rk is out of range.
int get_val(Node *o, int rk) {
    if (o == null) return -INF;
    if (o->ch[0]->s + 1 < rk) return get_val(o->ch[1], rk - o->ch[0]->s - 1);
    if (o->ch[0]->s + 1 > rk) return get_val(o->ch[0], rk);
    return o->v;
}

// Returns 1 + the number of stored values strictly less than v (v need not be
// present; the old version returned INF in that case).
int get_rank(Node *o, int v) {
    int less = 0;
    while (o != null) {
        if (o->v < v) less += o->ch[0]->s + 1, o = o->ch[1];  // o and its left subtree are < v
        else o = o->ch[0];                                     // nothing on the right is < v
    }
    return less + 1;
}

// Largest stored value strictly less than v, or -INF if there is none.
int get_pre(Node *o, int v) {
    if (o == null) return -INF;
    if (o->v < v) return max(o->v, get_pre(o->ch[1], v));
    return get_pre(o->ch[0], v);
}

// Smallest stored value strictly greater than v, or INF if there is none.
int get_next(Node *o, int v) {
    if (o == null) return INF;
    if (o->v > v) return min(o->v, get_next(o->ch[0], v));
    return get_next(o->ch[1], v);
}

int n, opt, v;

int main() {
    initnull();
    root = null;
    n = read();
    rep(i, 1, n) {
        opt = read(), v = read();
        if (opt == 1) insert(root, v);
        if (opt == 2) remove(root, v);
        if (opt == 3) printf("%d\n", get_rank(root, v));
        if (opt == 4) printf("%d\n", get_val(root, v));
        if (opt == 5) printf("%d\n", get_pre(root, v));
        if (opt == 6) printf("%d\n", get_next(root, v));
    }
    return 0;
}