洛谷 P3391 【模板】文艺平衡树

Description

P3391 【模板】文艺平衡树

Solution

方法一(fhq-treap)

文艺平衡树实际上只有一个操作,区间翻转

那么我们来看看如何实现。

大体上是一致的,但是分裂时要有所改变。

平常我们写的 \(fhq-treap\) 是以数值来分裂的,这里我们换个分裂方式。

以子树大小分裂

什么意思呢?

其实跟查询排名为 \(x\) 的数的值差不多。

就是不停地往左子树跳,或者往右子树跳,去寻找值。

这里我们先分析一下代码。

分裂

inline void split(int x, int k, int &a, int &b){
    if(!x){
        a = b = 0;
        return;
    }
    pushdown(x);
    if(t[ls(x)].siz < k){                        //k 大于 左子树大小
        a = x;
        split(rs(x), k - t[ls(x)].siz - 1, rs(x), b);    //去右子树查询,(-1 是减去父节点)
    }else{                                //在左子树
        b = x;
        split(ls(x), k, a, ls(x));                //去左子树找
    }
    pushup(x);
}

翻转

void reverse(int l, int r){
    split(root, l - 1, a, b);        // a: [1, l - 1]    b: [l, n]        子树 a, b, c 表示的区间范围,自己理解一下
    split(b, r - l + 1, b, c);        // b: [l, r]        c: [r + 1, n]
    t[b].flag ^= 1;                //打上标记
    root = merge(a, merge(b, c));            //合并回去
}

下传标记

inline void pushdown(int x){
    if(t[x].flag){
        swap(ls(x), rs(x));        //直接翻转
        if(ls(x)) t[ls(x)].flag ^= 1;    //下传标记
        if(rs(x)) t[rs(x)].flag ^= 1;
        t[x].flag = 0;
    }
}

因为题目要求输出完整序列,所以 \(dfs\) 输出,别忘了输出的时候 \(pushdown\) (虽然我也不知道有没有影响,懒得试了)

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]

using namespace std;

const int N = 2e5 + 10;
int n, m, tot, root;
int a, b, c;

inline int read(){
    int x = 0;
    char ch = getchar();
    while(ch < '0' || ch > '9') ch = getchar();
    while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x;
}

struct Tree{
    int val, ch[2], wei, siz, flag;
}t[N];

inline void pushup(int x){
    t[x].siz = t[ls(x)].siz + t[rs(x)].siz + 1;
}

inline void pushdown(int x){
    if(t[x].flag){
        swap(ls(x), rs(x));
        if(ls(x)) t[ls(x)].flag ^= 1;
        if(rs(x)) t[rs(x)].flag ^= 1;
        t[x].flag = 0;
    }
}

inline void split(int x, int k, int &a, int &b){
    if(!x){
        a = b = 0;
        return;
    }
    pushdown(x);
    if(t[ls(x)].siz < k){
        a = x;
        split(rs(x), k - t[ls(x)].siz - 1, rs(x), b);
    }else{
        b = x;
        split(ls(x), k, a, ls(x));
    }
    pushup(x);
}

inline int merge(int x, int y){
    if(!x || !y) return x + y;
    if(t[x].wei < t[y].wei){
        pushdown(x);
        rs(x) = merge(rs(x), y);
        pushup(x);
        return x;
    }else{
        pushdown(y);
        ls(y) = merge(x, ls(y));
        pushup(y);
        return y;
    }
}

void insert(int k){
    t[++tot].val = k, t[tot].siz = 1, t[tot].wei = rand();
    root = merge(root, tot);
}

void reverse(int l, int r){
    split(root, l - 1, a, b);
    split(b, r - l + 1, b, c);
    t[b].flag ^= 1;
    root = merge(a, merge(b, c));
}

void print(int x){
    if(!x) return;
    pushdown(x);
    print(ls(x));
    printf("%d ", t[x].val);
    print(rs(x));
}

int main(){
    n = read(), m = read();
    for(int i = 1; i <= n; i++)
        insert(i);
    while(m--){
        int l, r;
        l = read(), r = read();
        reverse(l, r);
    }
    print(root);
    printf("\n");
    return 0;
}

方法二(Splay)

其实就是 \(Splay\) 的板子,把脚标当作权值即可。

注意前后都插入一个数防止越界。

Code

#include <bits/stdc++.h>
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]

using namespace std;

namespace IO{
    inline int read(){
        int x = 0;
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x;
    }

    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;

const int N = 1e5 + 10;
int n, m;
int a[N];
struct Splay{
    int val, siz, cnt, ch[2], fa;
    bool rev;
}t[N];
int root, tot;

inline void pushup(int x){
    t[x].siz = t[ls(x)].siz + t[rs(x)].siz + 1;
}

inline void pushdown(int x){
    if(t[x].rev){
        t[ls(x)].rev ^= 1, t[rs(x)].rev ^= 1;
        swap(ls(x), rs(x));
        t[x].rev = 0;
    }
}

inline void rotate(int x){
    int y = t[x].fa, z = t[y].fa;
    int k = rs(y) == x;
    t[z].ch[rs(z) == y] = x, t[x].fa = z;
    t[y].ch[k] = t[x].ch[k ^ 1], t[t[x].ch[k ^ 1]].fa = y;
    t[x].ch[k ^ 1] = y, t[y].fa = x;
    pushup(y), pushup(x);
}

inline void splay(int x, int goal){
    while(t[x].fa != goal){
        int y = t[x].fa, z = t[y].fa;
        if(z != goal) rotate((ls(z) == y) ^ (ls(y) == x) ? x : y);
        rotate(x);
    }
    if(!goal) root = x;
}

inline void find(int x){
    int u = root;
    if(!u) return;
    while(t[u].ch[x > t[u].val] && x != t[u].val) u = t[u].ch[x > t[u].val];
    splay(u, 0);
}

inline int check(int x, int type){
    find(x);
    int u = root;
    if((t[u].val > x && type) || (t[u].val < x && !type)) return u;
    u = t[u].ch[type];
    while(t[u].ch[type ^ 1]) u = t[u].ch[type ^ 1];
    return u;
}

inline void insert(int x){
    int u = root, fa = 0;
    while(u && t[u].val != x) fa = u, u = t[u].ch[x > t[u].val];
    if(u) t[u].cnt++;
    else{
        u = ++tot;
        if(fa) t[fa].ch[x > t[fa].val] = u;
        t[u].ch[0] = t[u].ch[1] = t[u].rev = 0;
        t[u].cnt = t[u].siz = 1, t[u].val = x, t[u].fa = fa;
    }
    splay(u, 0);
}

inline void remove(int k){
    int x = check(k, 0), y = check(k, 1);
    splay(x, 0), splay(y, x);
    int del = ls(y);
    if(t[del].cnt > 1) t[del].cnt--, splay(del, 0);
    else t[y].ch[0] = 0;
    pushup(y), pushup(x);
}

inline int get_val(int x, int k){
    pushdown(x);
    if(k <= t[ls(x)].siz) return get_val(ls(x), k);
    else if(k > t[ls(x)].siz + t[x].cnt) return get_val(rs(x), k - t[ls(x)].siz - t[x].cnt);
    return t[x].val;
}

inline void solve(int l, int r){
    l = get_val(root, l), r = get_val(root, r + 2);
    splay(l, 0), splay(r, l);
    t[ls(rs(root))].rev ^= 1;
}

inline void print(int x){
    pushdown(x);
    if(ls(x)) print(ls(x));
    if(t[x].val > 1 && t[x].val < n + 2) write(t[x].val - 1), putchar(' ');
    if(rs(x)) print(rs(x));
}

int main(){
    n = read(), m = read();
    for(int i = 1; i <= n + 2; ++i) insert(i);
    for(int i = 1; i <= m; ++i){
        int l = read(), r = read();
        solve(l, r);
    }
    print(root), puts("");
    return 0;
}

\[\_EOF\_ %%\]

posted @ 2021-08-18 17:07  xixike  阅读(73)  评论(0编辑  收藏  举报