洛谷P1486 [NOI2004] 郁闷的出纳员 题解 splay tree
题目链接:https://www.luogu.com.cn/problem/P1486
涉及操作:
- 插入一个点
- 删除左子树(先找到 \(\ge m\) 的最小的点然后 splay 到根节点)(注:这里我用 \(m\) 表示薪资的下限)
- 找第 \(k\) 大的节点的权值(需要注意的是:某一个权值可能具有多个点)
示例程序:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5 + 5;
int n, m;
struct Node {
int s[2], p; // s[0] 左儿子 s[1] 右儿子 p 父节点
int num, sz; // num 当前节点个数,sz 子树节点个数
int v, flag; // v 权值(工资),flag 懒惰标记(变换的工资)
Node() {}
Node(int _v, int _p) {s[0] = s[1] = sz = num = flag = 0; v = _v; p = _p;}
} tr[maxn];
int root, idx;
void push_up(int x) {
tr[x].sz = tr[x].num + tr[tr[x].s[0]].sz + tr[tr[x].s[1]].sz;
}
void t_flag(int x, int tmp) {
if (x) {
tr[x].v += tmp;
tr[x].flag += tmp;
}
}
void push_down(int x) {
if (tr[x].flag) {
t_flag(tr[x].s[0], tr[x].flag);
t_flag(tr[x].s[1], tr[x].flag);
tr[x].flag = 0;
}
}
void f_s(int p, int u, bool k) {
if (p) tr[p].s[k] = u;
tr[u].p = p;
}
void rot(int x) {
int y = tr[x].p, z = tr[y].p;
bool k = tr[y].s[1] == x;
f_s(z, x, tr[z].s[1]==y);
f_s(y, tr[x].s[k^1], k);
f_s(x, y, k^1);
push_up(y), push_up(x);
}
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
(tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
rot(x);
}
if (!k) root = x;
}
void ins(int v) {
int u = root, p = 0;
while (u && tr[u].v != v) {
push_down(u);
p = u, u = tr[u].s[v > tr[u].v];
}
if (!u) {
tr[u = ++idx] = Node(v, p);
if (p) tr[p].s[v > tr[p].v] = u;
}
else
push_down(u);
tr[u].num++;
push_up(u);
splay(u, 0);
}
char op[2];
int k, cnt;
void ss() {
int u = root, p, x = 0;
while (u) {
p = u;
push_down(u);
if (tr[u].v >= m) x = u, u = tr[u].s[0];
else u = tr[u].s[1];
}
if (!x) tr[root = 0] = Node(0, 0);
else {
splay(x, 0);
tr[x].s[0] = 0;
push_up(x);
}
}
int get_k(int k) {
int u = root;
while (u) {
push_down(u);
if (tr[tr[u].s[1]].sz >= k) u = tr[u].s[1];
else if (tr[tr[u].s[1]].sz + tr[u].num >= k) return u;
else k -= tr[tr[u].s[1]].sz + tr[u].num, u = tr[u].s[0];
}
return -1;
}
int main() {
scanf("%d%d", &n, &m);
while (n--) {
scanf("%s%d", op, &k);
if (op[0] == 'I') {
if (k >= m)
cnt++, ins(k);
}
else if (op[0] == 'A') {
if (root)
tr[root].v += k, tr[root].flag += k;
}
else if (op[0] == 'S') {
if (root) {
tr[root].v -= k, tr[root].flag -= k;
ss();
}
}
else {
int u = get_k(k);
if (u == -1) puts("-1");
else printf("%d\n", tr[u].v);
}
}
printf("%d\n", cnt - tr[root].sz);
return 0;
}