Splay详解
因为博主太懒，这篇详解暂时咕了（尚未完成），下面先放两份 Splay 模板代码。
【模板】普通平衡树(Splay)
#include <bits/stdc++.h>
using namespace std;
// Capacity of the node pool; index 0 is reserved as the "null" node.
const int _ = 100000 + 10;
// Sentinel: getpre() returns -INF and getnxt() returns INF when no answer exists.
const int INF = 0x3f3f3f3f;

// One splay-tree node. Parent and children are indices into tr[].
struct node {
    int fa;     // parent index (0 for the root)
    int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
    int val;    // stored key
    int cnt;    // multiplicity of val
    int siz;    // total multiplicity stored in this subtree
} tr[_];

int root;     // index of the current root (0 = empty tree)
int tot = 0;  // number of pool slots handed out so far
int N;        // number of operations (read in main)
// Hands out the next pool slot and initializes it with one copy of `val`
// under parent `fa`. Children are not touched (they are zero for a slot that
// was never used before). Returns the new node's index.
int alloc(int val, int fa) {
    const int id = ++tot;
    auto &nd = tr[id];
    nd.fa = fa;
    nd.val = val;
    nd.cnt = 1;
    nd.siz = 1;
    return id;
}
// Recomputes x's subtree size: both child subtrees plus x's own multiplicity.
inline void update(int x) {
    const int lson = tr[x].ch[0];
    const int rson = tr[x].ch[1];
    tr[x].siz = tr[x].cnt + tr[lson].siz + tr[rson].siz;
}
// Returns 1 if x is its parent's right child, else 0.
// (For the root this inspects the null node tr[0].)
inline int ident(int x) {
    const int parent = tr[x].fa;
    return tr[parent].ch[1] == x ? 1 : 0;
}
// Makes x the `how`-side child of fa (0 = left, 1 = right) and records fa as
// x's parent. Either index may be 0, in which case the null node tr[0] is written.
inline void connect(int x, int fa, int how) {
    tr[fa].ch[how] = x;
    tr[x].fa = fa;
}
// Rotates x one level up over its parent while keeping in-order order.
// Moves `root` to x if the parent was the root, then recomputes sizes of the
// old parent (now x's child) first and x second.
void rotate(int x) {
    const int parent = tr[x].fa;
    const int grand = tr[parent].fa;
    const int side = ident(x);             // x's side under parent
    const int parentSide = ident(parent);  // parent's side under grand
    const int inner = tr[x].ch[side ^ 1];  // subtree that changes parents
    if (root == parent) root = x;
    connect(inner, parent, side);
    connect(parent, x, side ^ 1);
    connect(x, grand, parentSide);
    update(parent);
    update(x);
}
// Splays x upward until its parent is `to` (to == 0: make x the root).
// When x and its parent hang on the same side (zig-zig) the parent is rotated
// first; otherwise (zig-zag) x is rotated twice.
void splay(int x, int to) {
    for (int y = tr[x].fa; y != to; y = tr[x].fa) {
        const int z = tr[y].fa;
        if (z != to) {
            const bool sameSide = (tr[z].ch[0] == y) == (tr[y].ch[0] == x);
            if (sameSide) rotate(y);
            else rotate(x);
        }
        rotate(x);
    }
    if (to == 0) root = x;
}
// Inserts one copy of x. Each node on the search path gains one unit of
// size; the node that ends up holding x is splayed to the root.
void insert(int x) {
    if (root == 0) {
        root = alloc(x, 0);
        return;
    }
    for (int u = root;;) {
        tr[u].siz++;
        if (x == tr[u].val) {  // key already present: bump its multiplicity
            tr[u].cnt++;
            splay(u, 0);
            return;
        }
        const int dir = (x > tr[u].val) ? 1 : 0;
        const int child = tr[u].ch[dir];
        if (child != 0) {
            u = child;
            continue;
        }
        const int fresh = alloc(x, u);
        tr[u].ch[dir] = fresh;
        splay(fresh, 0);
        return;
    }
}
// Looks up the node holding `val`.
// Returns its index (splayed to the root), or 0 if `val` is not stored.
// The `u != 0` guard makes an empty tree return 0 at once. Without it, the
// search would walk from the null node tr[0], whose ch[]/fa fields connect()
// scribbles on, and could return a stale node or call splay(0, 0).
// On a miss the last node visited is splayed, so a long search path is paid
// for by restructuring instead of being repeated on every failed lookup.
int find(int val) {
    int u = root;
    int last = 0;
    while (u != 0 && tr[u].val != val) {
        last = u;
        u = tr[u].ch[val > tr[u].val];
    }
    if (u != 0) {
        splay(u, 0);
        return u;
    }
    if (last != 0) splay(last, 0);
    return 0;
}
// Deletes one copy of x; does nothing if x is absent.
// find() leaves the target node at the root, so only the root is reshaped.
void remove(int x) {
    const int pos = find(x);
    if (pos == 0) return;

    if (tr[pos].cnt > 1) {  // other copies remain: just shrink the counters
        tr[pos].cnt--;
        tr[pos].siz--;
        return;
    }

    const int left = tr[pos].ch[0];
    if (left == 0) {
        // No left subtree: the (possibly empty) right subtree becomes the tree.
        const int right = tr[pos].ch[1];
        root = right;
        if (right != 0) tr[right].fa = 0;
        return;
    }

    // Splay the predecessor (rightmost node of the left subtree) up to be the
    // root's child. It then has no right child, so the old right subtree is
    // hung there and the predecessor becomes the new root.
    int pred = left;
    while (tr[pred].ch[1] != 0) pred = tr[pred].ch[1];
    splay(pred, root);
    connect(tr[pos].ch[1], pred, 1);
    connect(pred, 0, 1);
    root = pred;
    update(pred);
}
int getrank(int val) {
int pos = find(val);
return tr[tr[pos].ch[0]].siz + 1;
}
// Returns the x-th smallest stored value (1-based, counting multiplicity)
// and splays its node to the root. Expects 1 <= x <= total size.
int kth(int x) {
    for (int u = root;;) {
        const int leftSize = tr[tr[u].ch[0]].siz;
        // ranks covered by the left subtree plus u itself
        const int rest = tr[u].siz - tr[tr[u].ch[1]].siz;
        if (leftSize < x && x <= rest) {
            splay(u, 0);
            return tr[u].val;
        }
        if (x < rest) {
            u = tr[u].ch[0];
        } else {
            x -= rest;
            u = tr[u].ch[1];
        }
    }
}
int getpre(int val) {
int u = root;
int ans = -INF;
while (u) {
if (tr[u].val < val && tr[u].val > ans) ans = tr[u].val;
if (val > tr[u].val) u = tr[u].ch[1];
else u = tr[u].ch[0];
}
return ans;
}
int getnxt(int val) {
int u = root;
int ans = INF;
while (u) {
if (tr[u].val > val && tr[u].val < ans) ans = tr[u].val;
if (val >= tr[u].val) u = tr[u].ch[1];
else u = tr[u].ch[0];
}
return ans;
}
// Reads N operations "op x" and runs them:
// 1 insert, 2 delete, 3 rank of x, 4 x-th smallest, 5 predecessor, 6 successor.
// Unless ONLINE_JUDGE is defined, I/O is redirected to splay.in / splay.out.
int main() {
#ifndef ONLINE_JUDGE
    freopen("splay.in", "r", stdin);
    freopen("splay.out", "w", stdout);
#endif
    scanf("%d", &N);
    while (N--) {
        int op, x;
        scanf("%d%d", &op, &x);
        switch (op) {
            case 1: insert(x); break;
            case 2: remove(x); break;
            case 3: printf("%d\n", getrank(x)); break;
            case 4: printf("%d\n", kth(x)); break;
            case 5: printf("%d\n", getpre(x)); break;
            case 6: printf("%d\n", getnxt(x)); break;
            default: break;
        }
    }
    return 0;
}
【模板】文艺平衡树(Splay区间翻转)
#include <bits/stdc++.h>
using namespace std;
// Reads the next decimal integer from stdin, skipping any non-digit
// characters before it; a '-' seen while skipping makes the result negative.
// Returns 0 if stdin ends before any digit is found.
// `ch` is an int, not a char, so EOF stays distinguishable from byte 0xFF.
// The explicit EOF check stops the skip loop from spinning forever at end of
// input (EOF is negative, so it always satisfies ch < '0').
inline int ty() {
    int ch = getchar();
    int x = 0, f = 1;
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Capacity of the node pool (index 0 is the null node) and the key used for
// the two boundary sentinels.
const int _ = 100000 + 10;
const int INF = 0x3f3f3f3f;

// Splay node. A set `tag` means this subtree still has to be reversed
// (pushdown() swaps this node's children and passes the flag on).
struct node {
    int fa;     // parent index (0 for the root)
    int ch[2];  // left / right child indices (0 = none)
    int val;    // stored key
    int cnt;    // multiplicity
    int siz;    // subtree size
    int tag;    // pending-reversal flag
} tr[_];

int root = 1;  // build() allocates the overall root first, so it gets index 1
int tot = 0;   // pool slots used so far
int N, M;      // sequence length and number of reversals
int num[_];    // initial keys: -INF sentinel, 1..N, INF sentinel
// Recomputes x's subtree size: both child subtrees plus x's own count.
inline void update(int x) {
    const int lson = tr[x].ch[0];
    const int rson = tr[x].ch[1];
    tr[x].siz = tr[x].cnt + tr[lson].siz + tr[rson].siz;
}
// Returns 1 if x is its parent's right child, else 0.
// (For the root this inspects the null node tr[0].)
inline int ident(int x) {
    const int parent = tr[x].fa;
    return tr[parent].ch[1] == x ? 1 : 0;
}
// Makes x the `how`-side child of fa (0 = left, 1 = right) and records fa as
// x's parent. Either index may be 0, in which case tr[0] is written.
inline void connect(int x, int fa, int how) {
    tr[fa].ch[how] = x;
    tr[x].fa = fa;
}
// Applies a pending reversal at x: flips both children's flags, swaps x's
// children and clears x's flag. The null node (x == 0) is ignored.
inline void pushdown(int x) {
    if (x == 0 || tr[x].tag == 0) return;
    int &lc = tr[x].ch[0];
    int &rc = tr[x].ch[1];
    tr[lc].tag ^= 1;
    tr[rc].tag ^= 1;
    swap(lc, rc);
    tr[x].tag = 0;
}
// Builds a balanced subtree over num[l..r], using the midpoint as the subtree
// root and attaching it under `fa`. Returns the subtree root's index, or 0
// for an empty range.
int build(int l, int r, int fa) {
    if (r < l) return 0;
    const int mid = (l + r) >> 1;
    const int cur = ++tot;
    auto &nd = tr[cur];
    nd.fa = fa;
    nd.val = num[mid];
    nd.cnt++;
    nd.siz++;
    nd.tag = 0;
    nd.ch[0] = build(l, mid - 1, cur);
    nd.ch[1] = build(mid + 1, r, cur);
    update(cur);
    return cur;
}
// Rotates x one level up over its parent while keeping in-order order.
// Moves `root` to x if the parent was the root, then recomputes sizes of the
// old parent first and x second.
void rotate(int x) {
    const int parent = tr[x].fa;
    const int grand = tr[parent].fa;
    const int side = ident(x);             // x's side under parent
    const int parentSide = ident(parent);  // parent's side under grand
    const int inner = tr[x].ch[side ^ 1];  // subtree that changes parents
    if (root == parent) root = x;
    connect(inner, parent, side);
    connect(parent, x, side ^ 1);
    connect(x, grand, parentSide);
    update(parent);
    update(x);
}
// Splays x upward until its parent is `to` (to == 0: make x the root).
// Same-side (zig-zig) steps rotate the parent first; zig-zag steps rotate x
// twice.
void splay(int x, int to) {
    for (int y = tr[x].fa; y != to; y = tr[x].fa) {
        const int z = tr[y].fa;
        if (z != to) {
            const bool sameSide = (tr[z].ch[0] == y) == (tr[y].ch[0] == x);
            if (sameSide) rotate(y);
            else rotate(x);
        }
        rotate(x);
    }
    if (to == 0) root = x;
}
// Returns the node at in-order position x (1-based, sentinels included),
// after splaying it to the root. Pending reversals are pushed down on the way
// so the children read at each step are the real ones.
int find(int x) {
    for (int u = root;;) {
        pushdown(u);
        const int leftSize = tr[tr[u].ch[0]].siz;
        // positions covered by the left subtree plus u itself
        const int rest = tr[u].siz - tr[tr[u].ch[1]].siz;
        if (leftSize < x && x <= rest) {
            splay(u, 0);
            return u;
        }
        if (x < rest) {
            u = tr[u].ch[0];
        } else {
            x -= rest;
            u = tr[u].ch[1];
        }
    }
}
// Reverses in-order positions l..r. Positions count the leading -INF
// sentinel, which is why main passes l + 1, r + 1. The node just before the
// range goes to the root and the node just after it becomes the root's child.
// The range is then exactly that child's left subtree, which gets flagged.
void rever(int l, int r) {
    const int before = find(l - 1);
    const int after = find(r + 1);
    splay(before, 0);
    splay(after, before);
    const int range = tr[tr[root].ch[1]].ch[0];
    tr[range].tag ^= 1;
}
// Prints the keys of the subtree rooted at x in order, each followed by a
// space, skipping the +INF / -INF sentinels. A node's pending reversal is
// pushed down before its children are read.
// Iterative with an explicit stack: after many splays the tree can degrade to
// a long path (depth on the order of N). One recursive call per level would
// then risk overflowing small default thread stacks.
void dfs(int x) {
    std::vector<int> stk;
    int u = x;
    while (u != 0 || !stk.empty()) {
        // Walk down the left spine, fixing up each node before descending.
        while (u != 0) {
            pushdown(u);
            stk.push_back(u);
            u = tr[u].ch[0];
        }
        u = stk.back();
        stk.pop_back();
        if (tr[u].val != INF && tr[u].val != -INF) printf("%d ", tr[u].val);
        u = tr[u].ch[1];
    }
}
// Reads N (sequence 1..N) and M reversal queries "l r", applies them, then
// prints the final sequence. Unless ONLINE_JUDGE is defined, I/O is
// redirected to splay.in / splay.out.
int main() {
#ifndef ONLINE_JUDGE
    freopen("splay.in", "r", stdin);
    freopen("splay.out", "w", stdout);
#endif
    N = ty();
    M = ty();
    // Sentinels at both ends give every query range a neighbour on each side.
    num[1] = -INF;
    num[N + 2] = INF;
    for (int pos = 2; pos <= N + 1; ++pos) num[pos] = pos - 1;
    build(1, N + 2, 0);
    while (M--) {
        const int l = ty();
        const int r = ty();
        rever(l + 1, r + 1);  // shift past the leading sentinel
    }
    dfs(root);
    return 0;
}
既然选择了远方，便只顾风雨兼程。