洛谷 P3391 【模板】文艺平衡树
Description
Solution
方法一(fhq-treap)
文艺平衡树实际上只有一个操作,区间翻转。
那么我们来看看如何实现。
大体上是一致的,但是分裂时要有所改变。
平常我们写的 \(fhq-treap\) 是以数值来分裂的,这里我们换个分裂方式。
以子树大小分裂
什么意思呢?
其实跟查询排名为 \(x\) 的数的值差不多。
就是不停地往左子树跳,或者往右子树跳,去寻找值。
这里我们先分析一下代码。
分裂
inline void split(int x, int k, int &a, int &b){
if(!x){
a = b = 0;
return;
}
pushdown(x);
if(t[ls(x)].siz < k){ //k 大于 左子树大小
a = x;
split(rs(x), k - t[ls(x)].siz - 1, rs(x), b); //去右子树查询,(-1 是减去父节点)
}else{ //在左子树
b = x;
split(ls(x), k, a, ls(x)); //去左子树找
}
pushup(x);
}
翻转
void reverse(int l, int r){
split(root, l - 1, a, b); // a: [1, l - 1] b: [l, n] 子树 a, b, c 表示的区间范围,自己理解一下
split(b, r - l + 1, b, c); // b: [l, r] c: [r + 1, n]
t[b].flag ^= 1; //打上标记
root = merge(a, merge(b, c)); //合并回去
}
下传标记
inline void pushdown(int x){
if(t[x].flag){
swap(ls(x), rs(x)); //直接翻转
if(ls(x)) t[ls(x)].flag ^= 1; //下传标记
if(rs(x)) t[rs(x)].flag ^= 1;
t[x].flag = 0;
}
}
因为题目要求输出完整序列,所以 \(dfs\) 输出,别忘了输出的时候 \(pushdown\) (虽然我也不知道有没有影响,懒得试了)
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
using namespace std;
const int N = 2e5 + 10;
int n, m, tot, root;
int a, b, c;
inline int read(){
int x = 0;
char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}
struct Tree{
int val, ch[2], wei, siz, flag;
}t[N];
inline void pushup(int x){
t[x].siz = t[ls(x)].siz + t[rs(x)].siz + 1;
}
inline void pushdown(int x){
if(t[x].flag){
swap(ls(x), rs(x));
if(ls(x)) t[ls(x)].flag ^= 1;
if(rs(x)) t[rs(x)].flag ^= 1;
t[x].flag = 0;
}
}
inline void split(int x, int k, int &a, int &b){
if(!x){
a = b = 0;
return;
}
pushdown(x);
if(t[ls(x)].siz < k){
a = x;
split(rs(x), k - t[ls(x)].siz - 1, rs(x), b);
}else{
b = x;
split(ls(x), k, a, ls(x));
}
pushup(x);
}
inline int merge(int x, int y){
if(!x || !y) return x + y;
if(t[x].wei < t[y].wei){
pushdown(x);
rs(x) = merge(rs(x), y);
pushup(x);
return x;
}else{
pushdown(y);
ls(y) = merge(x, ls(y));
pushup(y);
return y;
}
}
void insert(int k){
t[++tot].val = k, t[tot].siz = 1, t[tot].wei = rand();
root = merge(root, tot);
}
void reverse(int l, int r){
split(root, l - 1, a, b);
split(b, r - l + 1, b, c);
t[b].flag ^= 1;
root = merge(a, merge(b, c));
}
void print(int x){
if(!x) return;
pushdown(x);
print(ls(x));
printf("%d ", t[x].val);
print(rs(x));
}
int main(){
n = read(), m = read();
for(int i = 1; i <= n; i++)
insert(i);
while(m--){
int l, r;
l = read(), r = read();
reverse(l, r);
}
print(root);
printf("\n");
return 0;
}
方法二(Splay)
其实就是 \(Splay\) 的板子,把脚标当作权值即可。
注意前后都插入一个数防止越界。
Code
#include <bits/stdc++.h>
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
using namespace std;
namespace IO{
inline int read(){
int x = 0;
char ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}
template <typename T> inline void write(T x){
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace IO;
const int N = 1e5 + 10;
int n, m;
int a[N];
struct Splay{
int val, siz, cnt, ch[2], fa;
bool rev;
}t[N];
int root, tot;
inline void pushup(int x){
t[x].siz = t[ls(x)].siz + t[rs(x)].siz + 1;
}
inline void pushdown(int x){
if(t[x].rev){
t[ls(x)].rev ^= 1, t[rs(x)].rev ^= 1;
swap(ls(x), rs(x));
t[x].rev = 0;
}
}
inline void rotate(int x){
int y = t[x].fa, z = t[y].fa;
int k = rs(y) == x;
t[z].ch[rs(z) == y] = x, t[x].fa = z;
t[y].ch[k] = t[x].ch[k ^ 1], t[t[x].ch[k ^ 1]].fa = y;
t[x].ch[k ^ 1] = y, t[y].fa = x;
pushup(y), pushup(x);
}
inline void splay(int x, int goal){
while(t[x].fa != goal){
int y = t[x].fa, z = t[y].fa;
if(z != goal) rotate((ls(z) == y) ^ (ls(y) == x) ? x : y);
rotate(x);
}
if(!goal) root = x;
}
inline void find(int x){
int u = root;
if(!u) return;
while(t[u].ch[x > t[u].val] && x != t[u].val) u = t[u].ch[x > t[u].val];
splay(u, 0);
}
inline int check(int x, int type){
find(x);
int u = root;
if((t[u].val > x && type) || (t[u].val < x && !type)) return u;
u = t[u].ch[type];
while(t[u].ch[type ^ 1]) u = t[u].ch[type ^ 1];
return u;
}
inline void insert(int x){
int u = root, fa = 0;
while(u && t[u].val != x) fa = u, u = t[u].ch[x > t[u].val];
if(u) t[u].cnt++;
else{
u = ++tot;
if(fa) t[fa].ch[x > t[fa].val] = u;
t[u].ch[0] = t[u].ch[1] = t[u].rev = 0;
t[u].cnt = t[u].siz = 1, t[u].val = x, t[u].fa = fa;
}
splay(u, 0);
}
inline void remove(int k){
int x = check(k, 0), y = check(k, 1);
splay(x, 0), splay(y, x);
int del = ls(y);
if(t[del].cnt > 1) t[del].cnt--, splay(del, 0);
else t[y].ch[0] = 0;
pushup(y), pushup(x);
}
inline int get_val(int x, int k){
pushdown(x);
if(k <= t[ls(x)].siz) return get_val(ls(x), k);
else if(k > t[ls(x)].siz + t[x].cnt) return get_val(rs(x), k - t[ls(x)].siz - t[x].cnt);
return t[x].val;
}
inline void solve(int l, int r){
l = get_val(root, l), r = get_val(root, r + 2);
splay(l, 0), splay(r, l);
t[ls(rs(root))].rev ^= 1;
}
inline void print(int x){
pushdown(x);
if(ls(x)) print(ls(x));
if(t[x].val > 1 && t[x].val < n + 2) write(t[x].val - 1), putchar(' ');
if(rs(x)) print(rs(x));
}
int main(){
n = read(), m = read();
for(int i = 1; i <= n + 2; ++i) insert(i);
for(int i = 1; i <= m; ++i){
int l = read(), r = read();
solve(l, r);
}
print(root), puts("");
return 0;
}
\[\_EOF\_
%%\]