【bzoj3678】wangxz与OJ Splay
题目描述
给你一个序列,支持三种操作:
$0\ p\ a\ b$ :在序列的第 $p$ 个数后面插入 $a,a+1,...,b$ ;
$1\ a\ b$ :删除序列第 $a,a+1,...,b$ 位置的数;
$2\ p$ :查询序列 $p$ 位置的数。
输入
输入第一行包括两个正整数n(1<=n<=20000),m(1<=m<=20000),代表初始序列元素个数和操作个数。
接下来n个整数,为初始序列元素。
接下来m行,每行第一个为整数sym,
若sym=0,接下来有一个非负整数p,两个整数a,b;
若sym=1,接下来有两个正整数a,b;
若sym=2,接下来有一个正整数p;
p、x、y的含义及范围见题目描述。
在任何情况下,保证序列中的元素总数不超过100000。
保证题目涉及的所有数在int内。
输出
对每个sym=2,输出一行,包括一个整数,代表询问位置的元素。
样例输入
5 3
1 2 3 4 5
0 2 1 4
1 3 8
2 2
样例输出
2
题解
Splay
定点插入、区间删除,显然使用Splay维护。
然而这里有一个问题:插入操作是一段数一起插入的,单个插入的话时间会爆炸。
考虑到插入的数都是连续的等差数列,因此可以令Splay中每个节点代表一个等差数列。查询第 $k$ 个数时,如果找到的节点是等差数列,则将其分裂为所找数、左半部分及右半部分。
这样每次find操作最多新增两个节点,复杂度有了保证。
时间复杂度 $O(n\log n)$ 。
注意空间大小的问题,每次插入操作最多新增5个节点(2*2分裂+1插入),因此不内存回收的话数组至少要开到12W。
#include <cstdio> #define N 120010 int fa[N] , c[2][N] , si[N] , w[N] , pos[N] , tot , root; inline void pushup(int x) { si[x] = si[c[0][x]] + si[c[1][x]] + w[x]; } int build(int l , int r) { if(l > r) return 0; int mid = (l + r) >> 1; c[0][mid] = build(l , mid - 1) , fa[c[0][mid]] = mid; c[1][mid] = build(mid + 1 , r) , fa[c[1][mid]] = mid; w[mid] = 1 , pushup(mid); return mid; } inline void rotate(int &k , int x) { int y = fa[x] , z = fa[y] , l = (c[1][y] == x) , r = l ^ 1; if(y == k) k = x; else c[c[1][z] == y][z] = x; fa[x] = z , fa[y] = x , fa[c[r][x]] = y , c[l][y] = c[r][x] , c[r][x] = y; pushup(y) , pushup(x); } inline void splay(int &k , int x) { int y , z; while(x != k) { y = fa[x] , z = fa[y]; if(y != k) { if((c[0][y] == x) ^ (c[0][z] == y)) rotate(k , x); else rotate(k , y); } rotate(k , x); } } int find(int k , int x) { if(x <= si[c[0][k]]) return find(c[0][k] , x); else if(x > si[c[0][k]] + w[k]) return find(c[1][k] , x - si[c[0][k]] - w[k]); x -= si[c[0][k]]; if(x > 1) pos[++tot] = pos[k] , w[tot] = x - 1 , fa[tot] = k , fa[c[0][k]] = tot , c[0][tot] = c[0][k] , c[0][k] = tot , pushup(tot); if(x < w[k]) pos[++tot] = pos[k] + x , w[tot] = w[k] - x , fa[tot] = k , fa[c[1][k]] = tot , c[1][tot] = c[1][k] , c[1][k] = tot , pushup(tot); w[k] = 1 , pos[k] += x - 1; return k; } inline int split(int l , int r) { int a = find(root , l) , b = find(root , r + 2); splay(root , a) , splay(c[1][root] , b); return c[0][c[1][root]]; } int main() { int n , m , i , opt , x , y , z; scanf("%d%d" , &n , &m) , tot = n + 2; for(i = 2 ; i <= n + 1 ; i ++ ) scanf("%d" , &pos[i]); root = build(1 , n + 2); while(m -- ) { scanf("%d%d" , &opt , &x); if(opt == 0) { scanf("%d%d" , &y , &z) , split(x + 1 , x); c[0][c[1][root]] = ++tot , pos[tot] = y , w[tot] = si[tot] = z - y + 1 , fa[tot] = c[1][root]; pushup(c[1][root]) , pushup(root); } else if(opt == 1) { scanf("%d" , &y) , z = split(x , y); fa[z] = c[0][c[1][root]] = 0 , pushup(c[1][root]) , pushup(root); } else printf("%d\n" , pos[split(x , x)]); } return 0; }