平衡树模板 (AVL + FHQ Treap + Splay)
题目:普通平衡树 (operations: 1 insert, 2 delete, 3 rank of x, 4 k-th smallest, 5 predecessor, 6 successor)
AVL
#include <bits/stdc++.h>
using namespace std;
#define LL long long
#define MAXN 500005
#define INF INT_MAX
// One AVL node. ch[0]/ch[1] are the left/right child indices (0 = no child),
// val is the stored key, size is the number of nodes in this subtree and
// high is the subtree height (a fresh leaf has high 1; the unused node 0 keeps 0).
struct node {
LL ch[2], val, size, high;
} tr[MAXN];
// rt: index of the tree root (0 = empty tree); tot: last node index handed out.
LL rt, tot;
// Recompute size and height of node rt from its children (no-op for node 0).
void pushup(LL rt) {
    if (!rt)
        return;
    const LL lc = tr[rt].ch[0];
    const LL rc = tr[rt].ch[1];
    tr[rt].size = 1 + tr[lc].size + tr[rc].size;
    tr[rt].high = 1 + max(tr[lc].high, tr[rc].high);
}
// Right rotation: the left child of rt is lifted into rt's slot and rt becomes
// its right child. The lowered node is refreshed first, then the new root.
void zig(LL &rt) {
    const LL up = tr[rt].ch[0];
    LL &inner = tr[up].ch[1];  // subtree that changes parent
    tr[rt].ch[0] = inner;
    inner = rt;
    pushup(rt);
    pushup(up);
    rt = up;
}
// Left rotation: the right child of rt is lifted into rt's slot and rt becomes
// its left child. The lowered node is refreshed first, then the new root.
void zag(LL &rt) {
    const LL up = tr[rt].ch[1];
    LL &inner = tr[up].ch[0];  // subtree that changes parent
    tr[rt].ch[1] = inner;
    inner = rt;
    pushup(rt);
    pushup(up);
    rt = up;
}
// Left-right double rotation: rotate the left child left, then rt right.
void zagzig(LL &rt) {
    LL &lc = tr[rt].ch[0];
    zag(lc);
    zig(rt);
}
// Right-left double rotation: rotate the right child right, then rt left.
void zigzag(LL &rt) {
    LL &rc = tr[rt].ch[1];
    zig(rc);
    zag(rt);
}
// Restore the AVL balance at rt when its children's heights differ by exactly 2.
// Single rotation when the heavy grandchild is on the outside, double otherwise.
void maintain(LL &rt) {
    const LL lc = tr[rt].ch[0];
    const LL rc = tr[rt].ch[1];
    const LL hl = tr[lc].high;
    const LL hr = tr[rc].high;
    if (hl == hr + 2) {
        // Left-heavy.
        if (tr[tr[lc].ch[0]].high == hr + 1)
            zig(rt);
        else if (tr[tr[lc].ch[1]].high == hr + 1)
            zagzig(rt);
    } else if (hr == hl + 2) {
        // Right-heavy.
        if (tr[tr[rc].ch[1]].high == hl + 1)
            zag(rt);
        else if (tr[tr[rc].ch[0]].high == hl + 1)
            zigzag(rt);
    }
}
// Insert key x into the subtree whose root index is stored in rt.
// Keys equal to a node's key go to its right subtree. Rebalances on the way up.
void insert(LL &rt, LL x) {
    if (!rt) {
        rt = ++tot;
        tr[rt].val = x;
        tr[rt].size = 1;
        tr[rt].high = 1;
        return;
    }
    insert(tr[rt].ch[x < tr[rt].val ? 0 : 1], x);
    pushup(rt);
    maintain(rt);
}
// Remove one node holding v from the subtree rooted at rt, rebalancing on the
// way back up. Returns the value stored in the node that was unlinked.
// NOTE(review): if v is absent, the search still stops where it runs out of
// children (v < val with no left child, or v >= val with no right child) and
// removes that node, so callers should only delete values that exist.
LL del(LL &rt, LL v) {
LL t = 0;
// Stop here when this node holds v, or the search cannot descend further.
if ((v == tr[rt].val) || (v < tr[rt].val && tr[rt].ch[0] == 0) ||
(v >= tr[rt].val && tr[rt].ch[1] == 0)) {
if (tr[rt].ch[0] == 0 || tr[rt].ch[1] == 0) {
// At most one child: splice that child (or 0) into this slot.
t = tr[rt].val;
rt = tr[rt].ch[0] + tr[rt].ch[1];
pushup(rt);
return t;
} else {
// Two children (only reachable when v == val): keep this node but overwrite
// its value with the one removed from the left subtree, so exactly one node disappears.
t = tr[rt].val;
tr[rt].val = del(tr[rt].ch[0], v);
}
} else {
// Keep searching; like insert(), keys >= val live to the right.
if (v < tr[rt].val) {
t = del(tr[rt].ch[0], v);
} else {
t = del(tr[rt].ch[1], v);
}
}
// Something below changed: refresh size/height, then fix any imbalance.
pushup(rt);
maintain(rt);
return t;
}
// Rank of val: 1 + the number of stored keys strictly smaller than val.
inline LL getrank(LL rt, LL val) {
    LL rank = 1;
    while (rt) {
        if (val <= tr[rt].val) {
            rt = tr[rt].ch[0];
        } else {
            rank += tr[tr[rt].ch[0]].size + 1;
            rt = tr[rt].ch[1];
        }
    }
    return rank;
}
// Return the val-th smallest key (1-based) in the subtree rooted at rt.
// Returns 0 when val is outside [1, size]. The explicit empty-node stop keeps
// ranks <= 0 from recursing forever on node 0.
inline LL getxth(LL rt, LL val) {
    while (rt) {
        const LL lsz = tr[tr[rt].ch[0]].size;
        if (val <= lsz) {
            rt = tr[rt].ch[0];
        } else if (val == lsz + 1) {
            return tr[rt].val;
        } else {
            val -= lsz + 1;
            rt = tr[rt].ch[1];
        }
    }
    return 0;
}
// Largest key strictly smaller than val, or INT_MIN if there is none.
inline LL getpre(LL rt, LL val) {
    LL best = INT_MIN;
    while (rt) {
        if (tr[rt].val < val) {
            best = max(best, tr[rt].val);
            rt = tr[rt].ch[1];
        } else {
            rt = tr[rt].ch[0];
        }
    }
    return best;
}
// Smallest key strictly greater than val, or INT_MAX if there is none.
inline LL getnxt(LL rt, LL val) {
    LL best = INT_MAX;
    while (rt) {
        if (tr[rt].val > val) {
            best = min(best, tr[rt].val);
            rt = tr[rt].ch[0];
        } else {
            rt = tr[rt].ch[1];
        }
    }
    return best;
}
// Read n operations "opt x" and run them against the AVL tree.
// 1 insert, 2 delete, 3 rank, 4 k-th, 5 predecessor, anything else successor.
int main() {
    LL n, opt, x;
    scanf("%lld", &n);
    while (n--) {
        scanf("%lld%lld", &opt, &x);
        switch (opt) {
        case 1:
            insert(rt, x);
            break;
        case 2:
            del(rt, x);
            break;
        case 3:
            printf("%lld\n", getrank(rt, x));
            break;
        case 4:
            printf("%lld\n", getxth(rt, x));
            break;
        case 5:
            printf("%lld\n", getpre(rt, x));
            break;
        default:
            printf("%lld\n", getnxt(rt, x));
            break;
        }
    }
}
FHQ Treap
#include <bits/stdc++.h>
using namespace std;
#define LL long long
#define MAXN 500005
#define INF 99999999
// One treap node. ch[0]/ch[1] are child indices (0 = none), val is the key,
// pri is the random priority (merge() keeps the smaller pri nearer the root)
// and size is the number of nodes in this subtree.
struct node {
LL ch[2], val, pri, size;
} tr[MAXN];
// rt: treap root; root1/root2: scratch roots used by insert/del;
// tot: last node index handed out.
LL rt, root1, root2, tot;
// Recompute the subtree size of node r (no-op for node 0).
inline void pushup(LL r) {
    if (!r)
        return;
    const LL lc = tr[r].ch[0];
    const LL rc = tr[r].ch[1];
    tr[r].size = 1 + tr[lc].size + tr[rc].size;
}
// Split the treap rooted at rt by key: nodes with val <= v end up under xroot,
// nodes with val > v under yroot. rt is taken by value, so xroot/yroot may
// alias the caller's child slots or even the variable rt was read from.
inline void split(LL rt, LL &xroot, LL &yroot, LL v) {
if (rt == 0)
xroot = yroot = 0;
else if (v < tr[rt].val) {
// rt (and, by BST order, its right subtree) is > v: rt heads the right part,
// and the "> v" half of its left subtree becomes its new left child.
yroot = rt;
split(tr[rt].ch[0], xroot, tr[rt].ch[0], v);
} else {
// rt (and its left subtree) is <= v: rt heads the left part, and the
// "<= v" half of its right subtree becomes its new right child.
xroot = rt;
split(tr[rt].ch[1], tr[rt].ch[1], yroot, v);
}
pushup(rt);
}
// Join two treaps into rt, placing all of xroot's keys before yroot's keys
// (callers pass the halves produced by split, so keys in xroot are <= keys in yroot).
// The root with the smaller pri stays on top. xroot/yroot are copied in, so
// rt may alias one of their child slots.
inline void merge(LL &rt, LL xroot, LL yroot) {
if (xroot == 0 || yroot == 0) {
// One side is empty: the result is the other side (or 0).
rt = xroot + yroot;
} else if (tr[xroot].pri < tr[yroot].pri) {
// xroot stays on top; merge yroot into its right subtree.
rt = xroot;
merge(tr[rt].ch[1], tr[rt].ch[1], yroot);
} else {
// yroot stays on top; merge xroot into its left subtree.
rt = yroot;
merge(tr[rt].ch[0], xroot, tr[rt].ch[0]);
}
pushup(rt);
}
// Insert key v: split into (<= v | > v), attach a fresh single-node treap to
// the left half, then glue everything back together.
inline void insert(LL &rt, LL v) {
    split(rt, root1, root2, v);
    const LL fresh = ++tot;
    tr[fresh].val = v;
    tr[fresh].pri = rand();
    tr[fresh].size = 1;
    merge(root1, root1, fresh);
    merge(rt, root1, root2);
}
// Remove one copy of key v (keys are integers, so "< v" is "<= v - 1").
inline void del(LL &rt, LL v) {
    LL equal;
    split(rt, root1, root2, v);          // root1: keys <= v, root2: keys > v
    split(root1, root1, equal, v - 1);   // root1: keys < v, equal: keys == v
    // Drop the root of the "== v" part by merging its two children.
    merge(equal, tr[equal].ch[0], tr[equal].ch[1]);
    merge(rt, root1, equal);
    merge(rt, rt, root2);
}
// Rank of v: 1 + the number of stored keys strictly smaller than v.
inline LL getrank(LL rt, LL v) {
    LL rank = 1;
    while (rt) {
        if (v <= tr[rt].val) {
            rt = tr[rt].ch[0];
        } else {
            rank += tr[tr[rt].ch[0]].size + 1;
            rt = tr[rt].ch[1];
        }
    }
    return rank;
}
// Return the v-th smallest key (1-based) in the treap rooted at rt.
// Returns 0 when v is outside [1, size]. The explicit empty-node stop keeps
// ranks <= 0 (e.g. opt 5 asking for a missing predecessor) from recursing
// forever on node 0.
inline LL getxth(LL rt, LL v) {
    while (rt) {
        const LL lsz = tr[tr[rt].ch[0]].size;
        if (v <= lsz) {
            rt = tr[rt].ch[0];
        } else if (v == lsz + 1) {
            return tr[rt].val;
        } else {
            v -= lsz + 1;
            rt = tr[rt].ch[1];
        }
    }
    return 0;
}
// Read n operations "opt x" and run them against the treap.
// Predecessor/successor are expressed through getrank/getxth.
int main() {
    srand(time(0));
    LL n;
    scanf("%lld", &n);
    while (n--) {
        LL opt, x;
        scanf("%lld%lld", &opt, &x);
        switch (opt) {
        case 1:
            insert(rt, x);
            break;
        case 2:
            del(rt, x);
            break;
        case 3:
            printf("%lld\n", getrank(rt, x));
            break;
        case 4:
            printf("%lld\n", getxth(rt, x));
            break;
        case 5:
            // The key ranked just below x.
            printf("%lld\n", getxth(rt, getrank(rt, x) - 1));
            break;
        case 6:
            // The first key ranked after every copy of x (keys are integers).
            printf("%lld\n", getxth(rt, getrank(rt, x + 1)));
            break;
        }
    }
}
Splay
#include <bits/stdc++.h>
using namespace std;
#define LL long long
#define MAXN 500005
#define INF INT_MAX
// One splay node. ch[0]/ch[1] are child indices (0 = none), val is the key,
// fa is the parent index (0 for the root), cnt is how many copies of val were
// inserted and size is the total multiplicity of the subtree (sum of cnt).
struct node {
LL ch[2], val, fa, cnt, size;
} tr[MAXN];
// root: index of the current root; tot: last node index handed out.
LL root, tot;
// Recompute the subtree size of rt (counting multiplicities) from its children.
inline void pushup(LL rt) {
    LL total = tr[rt].cnt;
    total += tr[tr[rt].ch[0]].size;
    total += tr[tr[rt].ch[1]].size;
    tr[rt].size = total;
}
// Rotate x above its parent y, keeping in-order order, and refresh sizes.
// Node 0 is the shared "no node" slot: it is never written, so its ch/fa stay 0.
// Otherwise a later walk that reaches 0 (next(), kth()) would follow stale links.
void rotate(LL x) {
    LL y = tr[x].fa, z = tr[y].fa;
    LL k = tr[y].ch[1] == x;          // which side of y x hangs on
    if (z)
        tr[z].ch[tr[z].ch[1] == y] = x;
    tr[x].fa = z;
    LL moved = tr[x].ch[k ^ 1];       // x's inner subtree changes parent to y
    tr[y].ch[k] = moved;
    if (moved)
        tr[moved].fa = y;
    tr[x].ch[k ^ 1] = y;
    tr[y].fa = x;
    pushup(y);                        // y is now below x: refresh it first
    pushup(x);
}
// Rotate x upward until its parent is goal (goal == 0 makes x the new root).
// If x and its parent lean the same way, rotate the parent first; otherwise
// rotate x twice.
void splay(LL x, LL goal) {
    for (LL y = tr[x].fa; y != goal; y = tr[x].fa) {
        const LL z = tr[y].fa;
        if (z != goal) {
            const bool sameSide = (tr[y].ch[0] == x) == (tr[z].ch[0] == y);
            rotate(sameSide ? y : x);
        }
        rotate(x);
    }
    if (goal == 0)
        root = x;
}
// Insert one copy of x. An existing node with key x just has its cnt bumped;
// otherwise a new leaf is attached under the last node visited. The touched
// node is then splayed to the root.
void insert(LL x) {
    LL u = root, fa = 0;
    while (u && tr[u].val != x) {
        fa = u;
        u = tr[u].ch[x > tr[u].val];
    }
    if (u) {
        tr[u].cnt++;
        // When u is already the root, splay() below does no rotation and would
        // leave its size stale; refresh it here. Ancestors of a non-root u are
        // recomputed by the rotations.
        pushup(u);
    } else {
        u = ++tot;
        if (fa)
            tr[fa].ch[x > tr[fa].val] = u;
        tr[tot].val = x;
        tr[tot].size = tr[tot].cnt = 1;
        tr[tot].fa = fa;
        tr[tot].ch[0] = tr[tot].ch[1] = 0;
    }
    splay(u, 0);
}
// Walk toward key x and splay the node where the walk ends to the root: the
// node holding x if present, otherwise the last node on the search path.
void find(LL x) {
    if (!root)
        return;
    LL u = root;
    while (x != tr[u].val) {
        const LL step = tr[u].ch[x > tr[u].val];
        if (!step)
            break;
        u = step;
    }
    splay(u, 0);
}
// Index of x's strict predecessor (f == 0) or strict successor (f == 1).
// After find(x) the root is x or a neighbour of x. If the root is not already
// the answer, take the extreme node of its subtree on side f.
LL next(LL x, LL f) {
    find(x);
    LL cur = root;
    const bool rootIsAnswer = f ? (x < tr[cur].val) : (x > tr[cur].val);
    if (rootIsAnswer)
        return cur;
    cur = tr[cur].ch[f];
    for (LL step = tr[cur].ch[f ^ 1]; step; step = tr[cur].ch[f ^ 1])
        cur = step;
    return cur;
}
// Remove one copy of x. Splaying x's predecessor to the root and its successor
// just below it leaves x's node (if any) as the successor's left child.
// Presumably relies on the +-2147483647 sentinels inserted by main() so that
// both neighbours always exist.
void del(LL x) {
    LL la = next(x, 0);
    LL ne = next(x, 1);
    splay(la, 0);
    splay(ne, la);
    LL dele = tr[ne].ch[0];
    if (tr[dele].cnt > 1) {
        tr[dele].cnt--;
        splay(dele, 0);  // the rotations refresh sizes along the path
    } else {
        tr[ne].ch[0] = 0;
        // No rotation happens here, so refresh the two ancestors whose subtree
        // just lost a node; otherwise their sizes stay one too large.
        pushup(ne);
        pushup(la);
    }
}
// Return the x-th smallest key (1-based, counting multiplicities and every
// stored node, sentinels included), or 0 when x is outside [1, size].
// Without the x < 1 check the descent below can keep stepping left into node 0
// and never terminate.
LL kth(LL x) {
    LL u = root;
    if (x < 1 || tr[u].size < x)
        return 0;
    while (true) {
        LL y = tr[u].ch[0];
        if (x > tr[y].size + tr[u].cnt) {
            // Skip the left subtree and this node's copies.
            x -= tr[y].size + tr[u].cnt;
            u = tr[u].ch[1];
        } else if (tr[y].size >= x) {
            u = y;
        } else {
            return tr[u].val;
        }
    }
}
int main() {
LL n;
insert(-2147483647);
insert(+2147483647);
scanf("%lld", &n);
while (n--) {
LL opt, x;
scanf("%lld%lld", &opt, &x);
if (opt == 1)
insert(x);
else if (opt == 2)
del(x);
else if (opt == 3) {
find(x);
printf("%lld\n", tr[tr[root].ch[0]].size);
} else if (opt == 4)
printf("%lld\n", kth(x + 1));
else
printf("%lld\n", tr[next(x, opt - 5)].val);
}
}
作者:zswagnziye
-------------------------------------------
个性签名:独学而无友,则孤陋而寡闻。做一个灵魂有趣的人!
如果觉得这篇文章对你有小小的帮助的话,记得在右下角点个“推荐”哦,博主在此感谢!