【Coel.学习笔记】Splay-区间翻转操作

摸了几天鱼ww
今天写一下 Splay 的区间翻转操作,也算是填上之前的大坑把~

操作详解

我们先来复习一下 Splay 的核心操作,借此引入一些翻转操作需要的新东西。

储存信息

一棵 Splay 需要维护的信息有:左右儿子的编号,父亲节点的编号,子树大小以及当前节点的值。

对于翻转操作,我们还需要维护一个懒惰标记,之后会用到。

struct node {
    int ch[2], fa, id;
    int size, lazy;
    void build(int v, int p) { id = v, fa = p, size = 1; } //初始化
} t[maxn];
int root, idx; //当前根节点和编号计数

大小汇总 pushup 标记下放 pushdown

大小汇总和 FHQ_Treap 一样,就是左子树加右子树加根节点。
标记下放用来处理翻转操作。判断当前节点是否存在标记,如果存在,则交换左右儿子并且对他们的懒标记取反,最后删除自己的标记。

void pushup(int x) {
    t[x].size = t[t[x].ch[0]].size + t[t[x].ch[1]].size + 1;
}

void pushdown(int x) {
    if (t[x].lazy) {
        swap(t[x].ch[0], t[x].ch[1]);
        t[t[x].ch[0]].lazy ^= 1;
        t[t[x].ch[1]].lazy ^= 1;
        t[x].lazy = 0;
    }
}

旋转和伸展

这一点还是比较复杂的,所以只给出具体实现。这里的实现和之前写的略有不同。

void rotate(int x) {
    int fa = t[x].fa, gfa = t[fa].fa; //父亲节点和祖父节点
    bool is_rson = (t[fa].ch[1] == x); //判断儿子类型,0 左 1 右
    t[gfa].ch[t[gfa].ch[1] == fa] = x, t[x].fa = gfa;
    t[fa].ch[is_rson] = t[x].ch[is_rson ^ 1], t[t[x].ch[is_rson ^ 1]].fa = fa;
    t[x].ch[is_rson ^ 1] = fa, t[fa].fa = x;
    pushup(fa), pushup(x);
}

void splay(int x, int k) { //把 x 转到 k 的下方
    while (t[x].fa != k) {
        int fa = t[x].fa, gfa = t[fa].fa;
        if (gfa != k)
            if ((t[fa].ch[1] == x) != (t[gfa].ch[1] == fa)) //三点不共线,先转儿子
                rotate(x);
            else rotate(fa); //三点共线,转父亲
        rotate(x); //三个情况都要转一次儿子
    }
    if (!k) root = x;
}

查询第 k 大

经典的递归操作,和 FHQ_Treap 一样。别忘了每次迭代都要下放标记。

int kth(int k) {
    int u = root;
    while (true) {
        pushdown(u);
        if (t[t[u].ch[0]].size >= k) u = t[u].ch[0];
        else if (t[t[u].ch[0]].size + 1 == k) return u;
        else k -= t[t[u].ch[0]].size + 1, u = t[u].ch[1];
    }
    return -1; //数据合理则不会存在这一情况
}

翻转过程

文艺平衡树这题不需要太多额外的东西,所以很好做。
翻转其实就三步:

  1. 对于区间 \(l,r\),找到 \(l\)\(r+2\) 的第 \(k\) 大对应值。这里找 \(r+2\) 是由于我们额外插入了两个哨兵节点 \(0,n+1\) 防止出现意外错误。
  2. 新找到的两个对应值记作 \(L,R\),将 \(L\) 伸展到根节点, \(R\) 伸展到 \(L\) 下方。
  3. \(R\) 的左儿子懒标记清除。

最后做一遍中序遍历就可以得到整棵子树的结果了。
完整代码如下:

#include <iostream>

using namespace std;

const int maxn = 1e5 + 10;

int n, m;

struct node {
    int ch[2], fa, id;
    int size, lazy;
    void build(int v, int p) { id = v, fa = p, size = 1; }
} t[maxn];
int root, idx;

void pushup(int x) {
    t[x].size = t[t[x].ch[0]].size + t[t[x].ch[1]].size + 1;
}

void pushdown(int x) {
    if (t[x].lazy) {
        swap(t[x].ch[0], t[x].ch[1]);
        t[t[x].ch[0]].lazy ^= 1;
        t[t[x].ch[1]].lazy ^= 1;
        t[x].lazy = 0;
    }
}

void rotate(int x) {
    int fa = t[x].fa, gfa = t[fa].fa;
    bool is_rson = (t[fa].ch[1] == x);
    t[gfa].ch[t[gfa].ch[1] == fa] = x, t[x].fa = gfa;
    t[fa].ch[is_rson] = t[x].ch[is_rson ^ 1], t[t[x].ch[is_rson ^ 1]].fa = fa;
    t[x].ch[is_rson ^ 1] = fa, t[fa].fa = x;
    pushup(fa), pushup(x);
}

void splay(int x, int k) {
    while (t[x].fa != k) {
        int fa = t[x].fa, gfa = t[fa].fa;
        if (gfa != k)
            if ((t[fa].ch[1] == x) != (t[gfa].ch[1] == fa))
                rotate(x);
            else rotate(fa);
        rotate(x);
    }
    if (!k) root = x;
}

void insert(int x) {
    int u = root, p = 0;
    while (u) p = u, u = t[u].ch[x > t[u].id];
    u = ++idx;
    if (p) t[p].ch[x > t[p].id] = u;
    t[u].build(x, p);
    splay(u, 0);
}

int kth(int k) {
    int u = root;
    while (true) {
        pushdown(u);
        if (t[t[u].ch[0]].size >= k) u = t[u].ch[0];
        else if (t[t[u].ch[0]].size + 1 == k) return u;
        else k -= t[t[u].ch[0]].size + 1, u = t[u].ch[1];
    }
    return -1;
}

void dfs(int u) {
    pushdown(u);
    if (t[u].ch[0]) dfs(t[u].ch[0]);
    if (t[u].id >= 1 && t[u].id <= n) cout << t[u].id << ' ';
    if (t[u].ch[1]) dfs(t[u].ch[1]);
}

int main(void) {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 0; i <= n + 1; i++) insert(i);
    while (m--) {
        int l, r;
        cin >> l >> r;
        l = kth(l), r = kth(r + 2);
        splay(l, 0), splay(r, l);
        t[t[r].ch[0]].lazy ^= 1;
    }
    dfs(root);
    return 0;
}

接下来看一道 Splay 的经典题。

[NOI2005] 维护数列

洛谷传送门
维护一个给定数列,要求实现如下操作:

编号 名称 格式 说明
1 插入 \(\operatorname{INSERT}\ posi \ tot \ c_1 \ c_2 \cdots c_{tot}\) 在当前数列的第 \(posi\) 个数字后插入 \(tot\) 个数字:\(c_1, c_2 \cdots c_{tot}\);若在数列首插入,则 \(posi\)\(0\)
2 删除 \(\operatorname{DELETE} \ posi \ tot\) 从当前数列的第 \(posi\) 个数字开始连续删除 \(tot\) 个数字
3 修改 \(\operatorname{MAKE-SAME} \ posi \ tot \ c\) 从当前数列的第 \(posi\) 个数字开始的连续 \(tot\) 个数字统一修改为 \(c\)
4 翻转 \(\operatorname{REVERSE} \ posi \ tot\) 取出从当前数列的第 \(posi\) 个数字开始的 \(tot\) 个数字,翻转后放入原来的位置
5 求和 \(\operatorname{GET-SUM} \ posi \ tot\) 计算从当前数列的第 \(posi\) 个数字开始的 \(tot\) 个数字的和并输出
6 求最大子段和 \(\operatorname{MAX-SUM}\) 求出当前数列中和最大的一段子列,并输出最大和

对于每个操作 \(5\) 和操作 \(6\) ,输出对应结果。


解析:很复杂的一道紫题,绝对的最上位……

\(5\) 个操作都是很基本的 Splay 操作,维护懒标记即可。我们具体看看第 \(6\) 个操作:求最大子段和。

回顾一下线段树求最大子段和的操作:利用分治思想,维护父节点的最大子段和。显然只会有三种可能:全部位于左儿子,全部位于右儿子,左右儿子各占一段。再根据一点贪心的想法,对于第三种可能,取到的子段一定是左儿子尾段的最大和加上右儿子尾段的最大和。据此对每个节点记录下最大子段和、前缀和与后缀和,这样就可以再查询时快速得到结果。

回到这一题,由于翻转区间、插入区间和删除区间无法在线段树实现,所以只能用 Splay 解决这个问题。注意写 pushdownpushup 时要注意分析推平操作和翻转操作的关系,并且结合刚才提到的最大子段和来操作。这里为了方便使用了传引用表示所需节点的左右儿子,代码更简单。

此外由于数据多、内存限制严格,我们需要使用一个叫做内存回收的技巧。把删除的节点编号放在一个数组里面,在插入时有限使用数组中的编号即可,这样就保证了内存使用量只与最大数列 (\(5\times 10^5\))有关。

代码如下(排在最优解第 \(9\),还是挺快的ww):

#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>

const int maxn = 5e5 + 10, inf = 1e9;

using namespace std;

int n, m;
struct node {
    int ch[2], fa, val;
    int rev, assi;  //翻转标记与推平标记
    int size, sum, ms, ls, rs;
    void init(int v, int p) {
        ch[0] = ch[1] = 0, fa = p, val = v;
        rev = assi = 0;
        size = 1, sum = ms = val;
        ls = rs = max(val, 0);
    }
} t[maxn];
int root, nodes[maxn], idx;
int w[maxn];

inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while (!isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
    return x * f;
}

inline void write(int x, int Ctr_sign) {
    static int buf[30];
    int top = 0;
    if (x < 0) {
        x = -x;
        putchar('-');
    }
    do {
        buf[top++] = x % 10;
        x /= 10;
    } while (x);
    while (top) putchar(buf[--top] + '0');
    putchar(Ctr_sign);
}

void pushup(int x) {
    node &u = t[x], &l = t[u.ch[0]], &r = t[u.ch[1]]; //利用传引用修改
    u.size = l.size + r.size + 1;
    u.sum = l.sum + r.sum + u.val;
    u.ls = max(l.ls, l.sum + u.val + r.ls);
    u.rs = max(r.rs, r.sum + u.val + l.rs);
    u.ms = max(max(l.ms, r.ms), l.rs + u.val + r.ls);
}

void pushdown(int x) {
    node &u = t[x], &l = t[u.ch[0]], &r = t[u.ch[1]];
    if (u.assi) {
        u.assi = u.rev = 0;
        if (u.ch[0]) l.assi = 1, l.val = u.val, l.sum = l.val * l.size;
        if (u.ch[1]) r.assi = 1, r.val = u.val, r.sum = r.val * r.size;
        if (u.val > 0) {
            if (u.ch[0]) l.ms = l.ls = l.rs = l.sum;
            if (u.ch[1]) r.ms = r.ls = r.rs = r.sum;
        } else {
            if (u.ch[0]) l.ms = l.val, l.ls = l.rs = 0;
            if (u.ch[1]) r.ms = r.val, r.ls = r.rs = 0;
        }
    } else if (u.rev) {
        u.rev = 0, l.rev ^= 1, r.rev ^= 1;
        swap(l.ls, l.rs), swap(r.ls, r.rs);
        swap(l.ch[0], l.ch[1]), swap(r.ch[0], r.ch[1]);
    }
}

int build(int l, int r, int fa) {
    int mid = (l + r) >> 1;
    int u = nodes[idx--];
    t[u].init(w[mid], fa);
    if (l < mid) t[u].ch[0] = build(l, mid - 1, u);
    if (mid < r) t[u].ch[1] = build(mid + 1, r, u);
    pushup(u);
    return u;
}

void recycle(int u) { //内存回收,把丢弃的子树编号放进数组里面
    if (t[u].ch[0]) recycle(t[u].ch[0]);
    if (t[u].ch[1]) recycle(t[u].ch[1]);
    nodes[++idx] = u;
}

void rotate(int x) {
    int fa = t[x].fa, gfa = t[fa].fa;
    int k = (t[fa].ch[1] == x);
    t[gfa].ch[t[gfa].ch[1] == fa] = x, t[x].fa = gfa;
    t[fa].ch[k] = t[x].ch[k ^ 1], t[t[x].ch[k ^ 1]].fa = fa;
    t[x].ch[k ^ 1] = fa, t[fa].fa = x;
    pushup(fa), pushup(x);
}

void splay(int x, int k) {
    while (t[x].fa != k) {
        int fa = t[x].fa, gfa = t[fa].fa;
        if (gfa != k) {
            if ((t[fa].ch[1] == x) ^ (t[gfa].ch[1] == fa))
                rotate(x);
            else
                rotate(fa);
        }
        rotate(x);
    }
    if (!k) root = x;
}

int kth(int k) {
    int u = root;
    while (u) {
        pushdown(u);
        if (t[t[u].ch[0]].size >= k)
            u = t[u].ch[0];
        else if (t[t[u].ch[0]].size + 1 == k)
            return u;
        else
            k -= t[t[u].ch[0]].size + 1, u = t[u].ch[1];
    }
    return -1;
}

int main(void) {
    for (int i = 1; i < maxn; i++) nodes[++idx] = i;
    n = read(), m = read();
    t[0].ms = w[0] = w[n + 1] = -inf;
    for (int i = 1; i <= n; i++) w[i] = read();
    root = build(0, n + 1, 0);
    while (m--) {
        char op[30];
        scanf("%s", op);
        if (!strcmp(op, "INSERT")) {
            int pos = read(), tot = read();
            for (int i = 0; i < tot; i++) w[i] = read();
            int l = kth(pos + 1), r = kth(pos + 2);
            splay(l, 0), splay(r, l);
            int u = build(0, tot - 1, r);
            t[r].ch[0] = u;
            pushup(r), pushup(l);
        } else if (!strcmp(op, "DELETE")) {
            int pos = read(), tot = read();
            int l = kth(pos), r = kth(pos + tot + 1);
            splay(l, 0), splay(r, l);
            recycle(t[r].ch[0]);
            t[r].ch[0] = 0;
            pushup(r), pushup(l);
        } else if (!strcmp(op, "MAKE-SAME")) {
            int pos = read(), tot = read(), c = read();
            int l = kth(pos), r = kth(pos + tot + 1);
            splay(l, 0), splay(r, l);
            node& son = t[t[r].ch[0]];
            son.assi = 1, son.val = c, son.sum = c * son.size;
            if (c > 0)
                son.ms = son.ls = son.rs = son.sum;
            else
                son.ms = c, son.ls = son.rs = 0;
            pushup(r), pushup(l);
        } else if (!strcmp(op, "REVERSE")) {
            int pos = read(), tot = read();
            int l = kth(pos), r = kth(pos + tot + 1);
            splay(l, 0), splay(r, l);
            node& son = t[t[r].ch[0]];
            son.rev ^= 1;
            swap(son.ls, son.rs);
            swap(son.ch[0], son.ch[1]);
            pushup(r), pushup(l);
        } else if (!strcmp(op, "GET-SUM")) {
            int pos = read(), tot = read();
            int l = kth(pos), r = kth(pos + tot + 1);
            splay(l, 0), splay(r, l);
            write(t[t[r].ch[0]].sum, '\n');
        } else
            write(t[root].ms, '\n');
    }
    return 0;
}
posted @ 2022-07-26 17:35  秋泉こあい  阅读(74)  评论(0编辑  收藏  举报