洛谷P8767 [蓝桥杯 2021 国 A] 冰山 题解 splay tree
题目链接:https://www.luogu.com.cn/problem/P8767
鸣谢:这道题的顺利解决得到了 7KByte 大佬的大力帮助,在此再次表示感谢。
首先,我的想法是这样的:
使用一个 splay tree 来维护这些冰山的信息。
对于每次操作的 \(x\) 和 \(y\):
- 先将所有节点的权值增加 \(x\)(这步操作只需要修改根节点的信息,然后进行一下懒惰标记即可,以后会 push_down 下去的)
- 然后增加一个权值为 \(y\) 的点(如果权值为 \(y\) 的点本身就存在的话,就将其数量增加 \(1\))
- 先将权值 \(\gt 0\) 且最小的节点 splay 为根节点,然后删除其左子树(因为此时根节点的左儿子的权值均 \(\le 0\))
- 再将权值 \(\le k\) 且最大的节点 splay 为根节点,此时其右子树对应的节点都是权值 \(\gt k\) 的,这些节点需要拆解为若干个权值为 \(k\) 的节点和若干个权值为 \(1\) 的节点,操作过程为:
- 首先需要为每一个节点维护两个信息:
- sum 表示以该节点为根节点的子树中所有节点权值与数量的乘积之和,比如:如果该节点所在子树中有 \(2\) 个权值为 \(3\) 的点,\(6\) 个权值为 \(5\) 的点,则该节点的 sum 为 \(2 \times 3 + 6 \times 5 = 36\)
- cnt 表示以该节点为根节点的自述中所有节点的数量,比如:如果该绩点所在子树中有 \(2\) 个权值为 \(3\) 的点,\(6\) 个权值为 \(5\) 的点,则该节点的 cnt 值为 \(2 + 6 = 8\)
- 然后,先计算出根节点的右儿子的 sum 值和 cnt 值,可以发现,这些节点中的每一个都会被分成 \(1\) 个权值为 \(k\) 的节点以及(剩余的)若干个权值为 \(1\) 的节点,所以我们需要做的事情是:
- 删除根节点的右子树
- 插入 cnt 个权值为 \(k\) 的节点
- 插入
sum - cnt * k
个权值为 \(1\) 的节点
- 首先需要为每一个节点维护两个信息:
(注:这样讨论在下面的实现是还有一些细节没有处理,我们在下面的内容中讨论)
具体实现时,首先由于我经常会用到变量小 k
,所以我把题目描述中的 k
开成了大K(即 K
)—— 即冰山的最大大小。
然后我定义了一个结构体来维护 splay tree 里面的节点信息,如下:
struct Node {
int s[2], p; // s[0] 左儿子 s[1] 右儿子 p 父节点
long long v, // 冰山体积
num, // 冰山个数
cnt, // 子树包含冰山个数
sum, // 子树包含冰山体积之和
flag; // 懒惰标记
Node() {};
Node(long long _v, int _p) {v = _v; p = _p; s[0] = s[1] = 0; num = cnt = sum = flag = 0;}
} tr[maxn];
其中:
s[0]
和s[1]
分别表示左儿子和右儿子节点的编号(如果没有则为 \(0\))v
表示当前节点的权值(对应的就是冰山的体积)num
表示体积为v
的冰山有多少个cnt
表示以当前节点为根节点的子树中包含的冰山的总数sum
表示以当前节点为根节点的子树中包含的冰山的总体积flag
是懒惰标记,它表示当前节点的所有子节点的体积需要整体增加的量
这里需要注明的是,我自己做懒惰标记的习惯是:修改当前节点信息的同时进行懒惰标记,然后将懒惰标记传给子节点的时候也会修改子节点的信息,因为写线段树的时候都是这么写的就习惯了。(因为有的大佬可能习惯是 pushdown 的时候更新当前节点,但是我已经习惯更新当前节点并进行懒惰标记,然后 pushdown 的时候一方面更新子节点,另一方面把懒惰标记传给子节点)
然后因为具体实现是经常需要进行形如 a = (a + b) % MOD;
的操作,所以我添加了一个 add
函数方便写:
void add(long long &a, long long b) {
a = (a + b % MOD) % MOD;
}
然后就是比较重要的 push up 和 push down 操作了。
push up
push_up 需要将子节点的信息更新到当前节点,只需要更新一下 cnt 和 sum
主要操作就是:
tr[x].cnt = tr[x].num + tr[tr[x].s[0]].cnt + tr[tr[x].s[1]].cnt;
tr[x].sum = tr[x].num * tr[x].v + tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum;
需要取一下模。
push down
push_down 需要将当前节点的懒惰标记(flag,对应的是冰山体积的整体增量)传递给子节点并且同时更新子节点。对于一个节点来说,当传递了一个值为 tmp
的体积增量时:
- flag 会增加 tmp
- v 会增加 tmp
- sum 会增加
cnt * tmp
(因为该节点的子树中所有节点对应的冰山体积都会增加tmp
)
主要操作是:
void t_flag(int x, long long tmp) {
if (x) {
tr[x].flag += tmp;
tr[x].v += tmp;
add(tr[x].sum, tr[x].cnt * tmp);
}
}
void push_down(int x) {
if (tr[x].flag) {
t_flag(tr[x].s[0], tr[x].flag);
t_flag(tr[x].s[1], tr[x].flag);
tr[x].flag = 0;
}
}
旋转和 splay 操作
这部分的操作基本没有改动过,之前用它们解决过 AcWing 上的 splay 例题(原来的帖子:https://www.acwing.com/file_system/file/content/whole/index/content/7428637/)
对应的代码(多了一个 f_s(p, u, k) 函数用来认亲(u 是 p 的儿子,其中 k = 0 表示 u 是 p 的左儿子;k = 1 表示 u 是 p 的右儿子):
void f_s(int p, int u, bool k) {
tr[p].s[k] = u;
tr[u].p = p;
}
void rot(int x) {
int y = tr[x].p, z = tr[y].p;
bool k = tr[y].s[1] == x;
f_s(z, x, tr[z].s[1]==y);
f_s(y, tr[x].s[k^1], k);
f_s(x, y, k^1);
push_up(y), push_up(x);
}
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
(tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
rot(x);
}
if (!k) root = x;
}
插入操作
ins(v, num)
表示的是加入 num
个体积为 v
的冰山:
void ins(long long v, long long num) {
int u = root, p = 0;
while (u && tr[u].v != v) {
push_down(u);
p = u, u = tr[u].s[v > tr[u].v];
}
if (!u) {
tr[u = ++idx] = Node(v, p);
if (p) tr[p].s[v > tr[p].v] = u;
}
else
push_down(u); // 如果不是新建的节点,需要push_down一下
add(tr[u].num, num);
add(tr[u].cnt, num);
add(tr[u].sum, num * v);
splay(u, 0);
}
调整
这里有一个 check 函数,主要用来判断当前节点对应的冰山体积是否合法:
bool check(int u) {
return tr[u].v > 0 && tr[u].v <= K;
}
然后就是 get1()
函数和 get2()
函数了。
get1() 函数对应的就是我上面说的第 3 步操作:
- 先将权值 \(\gt 0\) 且最小的节点 splay 为根节点,然后删除其左子树(因为此时根节点的左儿子的权值均 \(\le 0\))
补充一点细节: 如果循环结束时没有找到 \(\gt 0\) 的节点(即下面代码中的 x
结束时仍然为 \(0\)),说明需要将整棵树删除。
void get1() {
int u = root, p = 0, x = 0;
while (u) {
push_down(u);
p = u;
if (tr[u].v > 0) {
x = u;
u = tr[u].s[0];
}
else u = tr[u].s[1];
}
if (x) {
splay(x, 0);
tr[x].s[0] = 0;
push_up(x);
}
else tr[root = 0] = Node(0, 0);
}
get2() 函数对应的就是我上面说的第 4 步操作:
- 再将权值 \(\le k\) 且最大的节点 splay 为根节点,此时其右子树对应的节点都是权值 \(\gt k\) 的,这些节点需要拆解为若干个权值为 \(k\) 的节点和若干个权值为 \(1\) 的节点
补充一点细节: 如果循环结束时没有找到 \(\gt 0\) 的节点(即下面代码中的 x
结束时仍然为 \(0\)),说明要么树是空的,要么所有节点都是 \(\gt K\) 的,如果所有节点都是 \(\gt K\) 的,应该处理整棵树,而不是根节点的右儿子。
void get2() {
int u = root, p = 0, x = 0;
while (u) {
push_down(u);
p = u;
if (tr[u].v <= K) {
x = u;
u = tr[u].s[1];
}
else u = tr[u].s[0];
}
int y;
if (x) splay(x, 0), y = tr[x].s[1];
else y = root;
if (y) {
long long cnt = tr[y].cnt, sum = tr[y].sum;
if (y != root) tr[x].s[1] = 0;
else tr[root = 0] = Node(0, 0);
push_up(root);
ins(K, cnt);
ins(1, (sum - cnt * K % MOD + MOD) % MOD);
}
}
完整代码如下:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 5;
const long long MOD = 998244353;
int n, m;
long long K;
struct Node {
int s[2], p; // s[0] 左儿子 s[1] 右儿子 p 父节点
long long v, // 冰山体积
num, // 冰山个数
cnt, // 子树包含冰山个数
sum, // 子树包含冰山体积之和
flag; // 懒惰标记
Node() {};
Node(long long _v, int _p) {v = _v; p = _p; s[0] = s[1] = 0; num = cnt = sum = flag = 0;}
} tr[maxn];
int root, idx;
void add(long long &a, long long b) {
a = (a + b % MOD) % MOD;
}
void push_up(int x) {
tr[x].cnt = tr[x].num + tr[tr[x].s[0]].cnt + tr[tr[x].s[1]].cnt;
tr[x].cnt %= MOD;
tr[x].sum = tr[x].num * tr[x].v + tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum;
tr[x].sum %= MOD;
tr[x].sum = (tr[x].sum + MOD) % MOD;
}
void t_flag(int x, long long tmp) {
if (x) {
tr[x].flag += tmp;
tr[x].v += tmp;
add(tr[x].sum, tr[x].cnt * tmp);
}
}
void push_down(int x) {
if (tr[x].flag) {
t_flag(tr[x].s[0], tr[x].flag);
t_flag(tr[x].s[1], tr[x].flag);
tr[x].flag = 0;
}
}
void f_s(int p, int u, bool k) {
if (p) tr[p].s[k] = u;
tr[u].p = p;
}
void rot(int x) {
int y = tr[x].p, z = tr[y].p;
bool k = tr[y].s[1] == x;
f_s(z, x, tr[z].s[1]==y);
f_s(y, tr[x].s[k^1], k);
f_s(x, y, k^1);
push_up(y), push_up(x);
}
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
(tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
rot(x);
}
if (!k) root = x;
}
void ins(long long v, long long num) {
int u = root, p = 0;
while (u && tr[u].v != v) {
push_down(u);
p = u, u = tr[u].s[v > tr[u].v];
}
if (!u) {
tr[u = ++idx] = Node(v, p);
if (p) tr[p].s[v > tr[p].v] = u;
}
else
push_down(u); // 如果不是新建的节点,需要push_down一下
add(tr[u].num, num);
add(tr[u].cnt, num);
add(tr[u].sum, num * v);
splay(u, 0);
}
bool check(int u) {
return tr[u].v > 0 && tr[u].v <= K;
}
void get1() {
int u = root, p = 0, x = 0;
while (u) {
push_down(u);
p = u;
if (tr[u].v > 0) {
x = u;
u = tr[u].s[0];
}
else u = tr[u].s[1];
}
if (x) {
splay(x, 0);
tr[x].s[0] = 0;
push_up(x);
}
else tr[root = 0] = Node(0, 0);
}
void get2() {
int u = root, p = 0, x = 0;
while (u) {
push_down(u);
p = u;
if (tr[u].v <= K) {
x = u;
u = tr[u].s[1];
}
else u = tr[u].s[0];
}
int y;
if (x) splay(x, 0), y = tr[x].s[1];
else y = root;
if (y) {
long long cnt = tr[y].cnt, sum = tr[y].sum;
if (y != root) tr[x].s[1] = 0;
else tr[root = 0] = Node(0, 0);
push_up(root);
ins(K, cnt);
ins(1, (sum - cnt * K % MOD + MOD) % MOD);
}
}
int main() {
scanf("%d%d%lld", &n, &m, &K);
for (int i = 0; i < n; i++) {
int v;
scanf("%d", &v);
ins(v, 1);
}
while (m--) {
int x, y;
scanf("%d%d", &x, &y);
t_flag(root, x);
ins(y, 1);
get1();
get2();
printf("%lld\n", tr[root].sum);
}
return 0;
}