洛谷P8767 [蓝桥杯 2021 国 A] 冰山 题解 splay tree

题目链接:https://www.luogu.com.cn/problem/P8767

鸣谢:这道题的顺利解决得到了 7KByte 大佬的大力帮助,在此再次表示感谢。

首先,我的想法是这样的:

使用一个 splay tree 来维护这些冰山的信息。

对于每次操作的 \(x\)\(y\)

  1. 先将所有节点的权值增加 \(x\)(这步操作只需要修改根节点的信息,然后进行一下懒惰标记即可,以后会 push_down 下去的)
  2. 然后增加一个权值为 \(y\) 的点(如果权值为 \(y\) 的点本身就存在的话,就将其数量增加 \(1\)
  3. 先将权值 \(\gt 0\) 且最小的节点 splay 为根节点,然后删除其左子树(因为此时根节点的左儿子的权值均 \(\le 0\)
  4. 再将权值 \(\le k\) 且最大的节点 splay 为根节点,此时其右子树对应的节点都是权值 \(\gt k\) 的,这些节点需要拆解为若干个权值为 \(k\) 的节点和若干个权值为 \(1\) 的节点,操作过程为:
    1. 首先需要为每一个节点维护两个信息:
      • sum 表示以该节点为根节点的子树中所有节点权值与数量的乘积之和,比如:如果该节点所在子树中有 \(2\) 个权值为 \(3\) 的点,\(6\) 个权值为 \(5\) 的点,则该节点的 sum 为 \(2 \times 3 + 6 \times 5 = 36\)
      • cnt 表示以该节点为根节点的自述中所有节点的数量,比如:如果该绩点所在子树中有 \(2\) 个权值为 \(3\) 的点,\(6\) 个权值为 \(5\) 的点,则该节点的 cnt 值为 \(2 + 6 = 8\)
    2. 然后,先计算出根节点的右儿子的 sum 值和 cnt 值,可以发现,这些节点中的每一个都会被分成 \(1\) 个权值为 \(k\) 的节点以及(剩余的)若干个权值为 \(1\) 的节点,所以我们需要做的事情是:
      1. 删除根节点的右子树
      2. 插入 cnt 个权值为 \(k\) 的节点
      3. 插入 sum - cnt * k 个权值为 \(1\) 的节点

(注:这样讨论在下面的实现是还有一些细节没有处理,我们在下面的内容中讨论)

具体实现时,首先由于我经常会用到变量小 k,所以我把题目描述中的 k 开成了大K(即 K)—— 即冰山的最大大小。

然后我定义了一个结构体来维护 splay tree 里面的节点信息,如下:

struct Node {
    int s[2], p;    // s[0] 左儿子 s[1] 右儿子 p 父节点
    long long v,    // 冰山体积
              num,  // 冰山个数
              cnt,  // 子树包含冰山个数
              sum,  // 子树包含冰山体积之和
              flag; // 懒惰标记

    Node() {};
    Node(long long _v, int _p) {v = _v; p = _p; s[0] = s[1] = 0; num = cnt = sum = flag = 0;}
} tr[maxn];

其中:

  • s[0]s[1] 分别表示左儿子和右儿子节点的编号(如果没有则为 \(0\)
  • v 表示当前节点的权值(对应的就是冰山的体积)
  • num 表示体积为 v 的冰山有多少个
  • cnt 表示以当前节点为根节点的子树中包含的冰山的总数
  • sum 表示以当前节点为根节点的子树中包含的冰山的总体积
  • flag 是懒惰标记,它表示当前节点的所有子节点的体积需要整体增加的量

这里需要注明的是,我自己做懒惰标记的习惯是:修改当前节点信息的同时进行懒惰标记,然后将懒惰标记传给子节点的时候也会修改子节点的信息,因为写线段树的时候都是这么写的就习惯了。(因为有的大佬可能习惯是 pushdown 的时候更新当前节点,但是我已经习惯更新当前节点并进行懒惰标记,然后 pushdown 的时候一方面更新子节点,另一方面把懒惰标记传给子节点)

然后因为具体实现是经常需要进行形如 a = (a + b) % MOD; 的操作,所以我添加了一个 add 函数方便写:

void add(long long &a, long long b) {
    a = (a + b % MOD) % MOD;
}

然后就是比较重要的 push up 和 push down 操作了。

push up

push_up 需要将子节点的信息更新到当前节点,只需要更新一下 cnt 和 sum

主要操作就是:

tr[x].cnt = tr[x].num + tr[tr[x].s[0]].cnt + tr[tr[x].s[1]].cnt;
tr[x].sum = tr[x].num * tr[x].v + tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum;

需要取一下模。

push down

push_down 需要将当前节点的懒惰标记(flag,对应的是冰山体积的整体增量)传递给子节点并且同时更新子节点。对于一个节点来说,当传递了一个值为 tmp 的体积增量时:

  • flag 会增加 tmp
  • v 会增加 tmp
  • sum 会增加 cnt * tmp(因为该节点的子树中所有节点对应的冰山体积都会增加 tmp

主要操作是:

void t_flag(int x, long long tmp) {
    if (x) {
        tr[x].flag += tmp;
        tr[x].v += tmp;
        add(tr[x].sum, tr[x].cnt * tmp);
    }
}

void push_down(int x) {
    if (tr[x].flag) {
        t_flag(tr[x].s[0], tr[x].flag);
        t_flag(tr[x].s[1], tr[x].flag);
        tr[x].flag = 0;
    }
}

旋转和 splay 操作

这部分的操作基本没有改动过,之前用它们解决过 AcWing 上的 splay 例题(原来的帖子:https://www.acwing.com/file_system/file/content/whole/index/content/7428637/

对应的代码(多了一个 f_s(p, u, k) 函数用来认亲(u 是 p 的儿子,其中 k = 0 表示 u 是 p 的左儿子;k = 1 表示 u 是 p 的右儿子):

void f_s(int p, int u, bool k) {
    tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    bool k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1]==y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

插入操作

ins(v, num) 表示的是加入 num 个体积为 v 的冰山:

void ins(long long v, long long num) {
    int u = root, p = 0;
    while (u && tr[u].v != v) {
        push_down(u);
        p = u, u = tr[u].s[v > tr[u].v];
    }
    if (!u) {
        tr[u = ++idx] = Node(v, p);
        if (p) tr[p].s[v > tr[p].v] = u;
    }
    else
        push_down(u);   // 如果不是新建的节点,需要push_down一下
    add(tr[u].num, num);
    add(tr[u].cnt, num);
    add(tr[u].sum, num * v);
    splay(u, 0);
}

调整

这里有一个 check 函数,主要用来判断当前节点对应的冰山体积是否合法:

bool check(int u) {
    return tr[u].v > 0 && tr[u].v <= K;
}

然后就是 get1() 函数和 get2() 函数了。

get1() 函数对应的就是我上面说的第 3 步操作:

  1. 先将权值 \(\gt 0\) 且最小的节点 splay 为根节点,然后删除其左子树(因为此时根节点的左儿子的权值均 \(\le 0\)

补充一点细节: 如果循环结束时没有找到 \(\gt 0\) 的节点(即下面代码中的 x 结束时仍然为 \(0\)),说明需要将整棵树删除。

void get1() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v > 0) {
            x = u;
            u = tr[u].s[0];
        }
        else u = tr[u].s[1];
    }
    if (x) {
        splay(x, 0);
        tr[x].s[0] = 0;
        push_up(x);
    }
    else tr[root = 0] = Node(0, 0);
}

get2() 函数对应的就是我上面说的第 4 步操作:

  1. 再将权值 \(\le k\) 且最大的节点 splay 为根节点,此时其右子树对应的节点都是权值 \(\gt k\) 的,这些节点需要拆解为若干个权值为 \(k\) 的节点和若干个权值为 \(1\) 的节点

补充一点细节: 如果循环结束时没有找到 \(\gt 0\) 的节点(即下面代码中的 x 结束时仍然为 \(0\)),说明要么树是空的,要么所有节点都是 \(\gt K\) 的,如果所有节点都是 \(\gt K\) 的,应该处理整棵树,而不是根节点的右儿子。

void get2() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v <= K) {
            x = u;
            u = tr[u].s[1];
        }
        else u = tr[u].s[0];
    }
    int y;
    if (x) splay(x, 0), y = tr[x].s[1];
    else y = root;
    if (y) {
        long long cnt = tr[y].cnt, sum = tr[y].sum;
        if (y != root) tr[x].s[1] = 0;
        else tr[root = 0] = Node(0, 0);
        push_up(root);
        ins(K, cnt);
        ins(1, (sum - cnt * K % MOD + MOD) % MOD);
    }
}

完整代码如下:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 5;
const long long MOD = 998244353;

int n, m;
long long K;
struct Node {
    int s[2], p;    // s[0] 左儿子 s[1] 右儿子 p 父节点
    long long v,    // 冰山体积
              num,  // 冰山个数
              cnt,  // 子树包含冰山个数
              sum,  // 子树包含冰山体积之和
              flag; // 懒惰标记

    Node() {};
    Node(long long _v, int _p) {v = _v; p = _p; s[0] = s[1] = 0; num = cnt = sum = flag = 0;}
} tr[maxn];
int root, idx;

void add(long long &a, long long b) {
    a = (a + b % MOD) % MOD;
}

void push_up(int x) {
    tr[x].cnt = tr[x].num + tr[tr[x].s[0]].cnt + tr[tr[x].s[1]].cnt;
    tr[x].cnt %= MOD;
    tr[x].sum = tr[x].num * tr[x].v + tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum;
    tr[x].sum %= MOD;
    tr[x].sum = (tr[x].sum + MOD) % MOD;
}

void t_flag(int x, long long tmp) {
    if (x) {
        tr[x].flag += tmp;
        tr[x].v += tmp;
        add(tr[x].sum, tr[x].cnt * tmp);
    }
}

void push_down(int x) {
    if (tr[x].flag) {
        t_flag(tr[x].s[0], tr[x].flag);
        t_flag(tr[x].s[1], tr[x].flag);
        tr[x].flag = 0;
    }
}

void f_s(int p, int u, bool k) {
    if (p) tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    bool k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1]==y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

void ins(long long v, long long num) {
    int u = root, p = 0;
    while (u && tr[u].v != v) {
        push_down(u);
        p = u, u = tr[u].s[v > tr[u].v];
    }
    if (!u) {
        tr[u = ++idx] = Node(v, p);
        if (p) tr[p].s[v > tr[p].v] = u;
    }
    else
        push_down(u);   // 如果不是新建的节点,需要push_down一下
    add(tr[u].num, num);
    add(tr[u].cnt, num);
    add(tr[u].sum, num * v);
    splay(u, 0);
}

bool check(int u) {
    return tr[u].v > 0 && tr[u].v <= K;
}

void get1() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v > 0) {
            x = u;
            u = tr[u].s[0];
        }
        else u = tr[u].s[1];
    }
    if (x) {
        splay(x, 0);
        tr[x].s[0] = 0;
        push_up(x);
    }
    else tr[root = 0] = Node(0, 0);
}

void get2() {
    int u = root, p = 0, x = 0;
    while (u) {
        push_down(u);
        p = u;
        if (tr[u].v <= K) {
            x = u;
            u = tr[u].s[1];
        }
        else u = tr[u].s[0];
    }
    int y;
    if (x) splay(x, 0), y = tr[x].s[1];
    else y = root;
    if (y) {
        long long cnt = tr[y].cnt, sum = tr[y].sum;
        if (y != root) tr[x].s[1] = 0;
        else tr[root = 0] = Node(0, 0);
        push_up(root);
        ins(K, cnt);
        ins(1, (sum - cnt * K % MOD + MOD) % MOD);
    }
}

int main() {
    scanf("%d%d%lld", &n, &m, &K);
    for (int i = 0; i < n; i++) {
        int v;
        scanf("%d", &v);
        ins(v, 1);
    }
    while (m--) {
        int x, y;
        scanf("%d%d", &x, &y);
        t_flag(root, x);
        ins(y, 1);
        get1();
        get2();
        printf("%lld\n", tr[root].sum);
    }
    return 0;
}
posted @ 2022-12-08 18:46  quanjun  阅读(77)  评论(0编辑  收藏  举报