【模板】【伸展树】Splay Tree小结
先发两个模板,都是网上找的略作修改后的。
1.普通版。支持lazy操作、重复值、找前驱后继、单个删除、值区间删除等。代码对应的题是BZOJ1588 - 营业额统计。
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int MAXN = 100011; 4 5 struct SplayTree{ 6 int cnt, rt; 7 int Add[MAXN]; 8 9 struct Node{ 10 int key, cnt, size, fa, son[2]; 11 }T[MAXN]; 12 13 inline void PushUp(int x){ 14 T[x].size=T[T[x].son[0]].size+T[T[x].son[1]].size+T[x].cnt; 15 } 16 17 inline void PushDown(int x){ 18 if(Add[x]){ 19 if(T[x].son[0]){ 20 T[T[x].son[0]].key+=Add[x]; 21 Add[T[x].son[0]]+=Add[x]; 22 } 23 if(T[x].son[1]){ 24 T[T[x].son[1]].key+=Add[x]; 25 Add[T[x].son[1]]+=Add[x]; 26 } 27 Add[x]=0; 28 } 29 } 30 31 inline int Newnode(int key, int fa){ //新建一个节点并返回 32 ++cnt; 33 T[cnt].key=key; 34 T[cnt].cnt=T[cnt].size=1; 35 T[cnt].fa=fa; 36 T[cnt].son[0]=T[cnt].son[1]=0; 37 return cnt; 38 } 39 40 inline void Rotate(int x, int p){ //0左旋 1右旋 41 int y=T[x].fa; 42 PushDown(y); 43 PushDown(x); 44 T[y].son[!p]=T[x].son[p]; 45 T[T[x].son[p]].fa=y; 46 T[x].fa=T[y].fa; 47 if(T[x].fa) 48 T[T[x].fa].son[T[T[x].fa].son[1] == y]=x; 49 T[x].son[p]=y; 50 T[y].fa=x; 51 PushUp(y); 52 PushUp(x); 53 } 54 55 void Splay(int x, int to){ //将x节点移动到To的子节点中 56 while(T[x].fa != to){ 57 if(T[T[x].fa].fa == to) 58 Rotate(x, T[T[x].fa].son[0] == x); 59 else{ 60 int y=T[x].fa, z=T[y].fa; 61 int p=(T[z].son[0] == y); 62 if(T[y].son[p] == x) 63 Rotate(x, !p), Rotate(x, p); //之字旋 64 else 65 Rotate(y, p), Rotate(x, p); //一字旋 66 } 67 } 68 if(to == 0) rt=x; 69 } 70 71 int GetKth(int k){ 72 if(!rt || k > T[rt].size) return -1e9; // 若要节点id,改为0 73 int x=rt; 74 while(x){ 75 PushDown(x); 76 if(k >= T[T[x].son[0]].size+1 && k <= T[T[x].son[0]].size+T[x].cnt) 77 break; 78 if(k > T[T[x].son[0]].size+T[x].cnt){ 79 k-=T[T[x].son[0]].size+T[x].cnt; 80 x=T[x].son[1]; 81 } 82 else 83 x=T[x].son[0]; 84 } 85 return T[x].key; // 若要节点id,改为x 86 } 87 88 int Find(int key){ //返回值为key的节点 若无返回0 若有将其转移到根处 89 if(!rt) return 0; 90 int x=rt; 91 while(x){ 92 PushDown(x); 93 if(T[x].key == key) break; 94 x=T[x].son[key > T[x].key]; 95 } 96 if(x) Splay(x, 0); 97 return x; 98 } 99 100 int Prev(){ //返回根节点的前驱 非重点 101 if(!rt || !T[rt].son[0]) return 0; 102 int x=T[rt].son[0]; 103 while(T[x].son[1]){ 104 PushDown(x); 105 x=T[x].son[1]; 106 } 107 Splay(x, 0); 108 return x; 109 } 110 111 int Succ(){ //返回根结点的后继 非重点 112 if(!rt || !T[rt].son[1]) return 0; 113 int x=T[rt].son[1]; 114 while(T[x].son[0]){ 115 PushDown(x); 116 x=T[x].son[0]; 117 } 118 Splay(x, 0); 119 return x; 120 } 121 122 void Insert(int key){ //插入key值 123 if(!rt) 124 rt=Newnode(key, 0); 125 else{ 126 int x=rt, y=0; 127 while(x){ 128 PushDown(x); 129 y=x; 130 if(T[x].key == key){ 131 T[x].cnt++; 132 T[x].size++; 133 break; 134 } 135 T[x].size++; 136 x=T[x].son[key > T[x].key]; 137 } 138 if(!x) 139 x=T[y].son[key > T[y].key]=Newnode(key, y); 140 Splay(x, 0); 141 } 142 } 143 144 void Delete(int key){ //删除值为key的节点1个 145 int x=Find(key); 146 if(!x) return; 147 if(T[x].cnt>1){ 148 T[x].cnt--; 149 PushUp(x); 150 return; 151 } 152 int y=T[x].son[0]; 153 while(T[y].son[1]) 154 y=T[y].son[1]; 155 int z=T[x].son[1]; 156 while(T[z].son[0]) 157 z=T[z].son[0]; 158 if(!y && !z){ 159 rt=0; 160 return; 161 } 162 if(!y){ 163 Splay(z, 0); 164 T[z].son[0]=0; 165 PushUp(z); 166 return; 167 } 168 if(!z){ 169 Splay(y, 0); 170 T[y].son[1]=0; 171 PushUp(y); 172 return; 173 } 174 Splay(y, 0); 175 Splay(z, y); 176 T[z].son[0]=0; 177 PushUp(z); 178 PushUp(y); 179 } 180 181 int GetRank(int key){ //获得值<=key的节点个数 182 if(!Find(key)){ 183 Insert(key); 184 int tmp=T[T[rt].son[0]].size; 185 Delete(key); 186 return tmp; 187 } 188 else 189 return T[T[rt].son[0]].size+T[rt].cnt; 190 } 191 192 void Delete(int l, int r){ //删除值在[l, r]中的所有节点 l!=r 193 if(!Find(l)) Insert(l); 194 int p=Prev(); 195 if(!Find(r)) Insert(r); 196 int q=Succ(); 197 if(!p && !q){ 198 rt=0; 199 return; 200 } 201 if(!p){ 202 T[rt].son[0]=0; 203 PushUp(rt); 204 return; 205 } 206 if(!q){ 207 Splay(p, 0); 208 T[rt].son[1]=0; 209 PushUp(rt); 210 return; 211 } 212 Splay(p, q); 213 T[p].son[1]=0; 214 PushUp(p); 215 PushUp(q); 216 } 217 218 int solve(int key){ 219 if(!rt) return 0; 220 int x = rt, res = 2e9; 221 while(x){ 222 res = min(res,abs(T[x].key-key)); 223 x = T[x].son[key > T[x].key]; 224 } 225 return res; 226 } 227 }spt; 228 229 int n,x,ans; 230 231 int main(){ 232 cin >> n >> x; 233 spt.Insert(x); 234 ans = x; 235 for(int i = 1;i < n;++i){ 236 scanf("%d",&x); 237 ans += spt.solve(x); 238 spt.Insert(x); 239 } 240 cout << ans << endl; 241 242 return 0; 243 }
2.区间操作版。支持区间更新、区间查询、区间翻转,也可以很轻松地再添加区间切割、插入等。代码对应的题是BZOJ1251 - 序列终结者。
1 /*bzoj 1251 序列终结者 2 题意: 3 给定一个长度为N的序列,每个序列的元素是一个整数。要支持以下三种操作: 4 1. 将[L,R]这个区间内的所有数加上V; 5 2. 将[L,R]这个区间翻转,比如1 2 3 4变成4 3 2 1; 6 3. 求[L,R]这个区间中的最大值; 7 最开始所有元素都是0。 8 限制: 9 N <= 50000, M <= 100000 10 思路: 11 伸展树 12 13 关键点: 14 1. 伸展树为左小右大的二叉树,所以旋转操作不会影响树的性质 15 2. 区间操作为: 16 int u = select(L - 1), v = select(R + 1); 17 splay(u, 0); splay(v, u); //通过旋转操作把询问的区间聚集到根的右子树的左子树下 18 因为伸展树为左小右大的二叉树,旋转操作后的所以对于闭区间[L, R]之间的所有元素都聚集在根的右子树的左子树下 19 因为闭区间[L, R], 20 1) 所以每次都要查开区间(L - 1, R + 1), 21 2) 所以伸展树元素1对应的标号为2, 22 3) 所以node[0]对应空节点,node[1]对应比所以元素标号都小的点,node[2 ~ n + 1]对应元素1 ~ n,node[n + 2]对应比所有元素标号都打的点,其中node[0], node[1], node[n + 2]都是虚节点,不代表任何元素。 23 */ 24 #include<bits/stdc++.h> 25 using namespace std; 26 27 #define LS(n) node[(n)].ch[0] 28 #define RS(n) node[(n)].ch[1] 29 30 const int N = 1e5 + 5; 31 const int INF = 0x3f3f3f3f; 32 struct Splay { 33 struct Node{ 34 int fa, ch[2]; 35 bool rev; 36 int val, lazy, mx, size; 37 void init(int _val) { 38 val = mx = _val; 39 size = 1; 40 lazy = rev = ch[0] = ch[1] = 0; 41 } 42 } node[N]; 43 int root; 44 45 void pushup(int n) { 46 node[n].mx = max(node[n].val, max(node[LS(n)].mx, node[RS(n)].mx)); 47 node[n].size = node[LS(n)].size + node[RS(n)].size + 1; 48 } 49 50 void pushdown(int n) { 51 if(n == 0) return ; 52 if(node[n].lazy) { 53 if(LS(n)) { 54 node[LS(n)].val += node[n].lazy; 55 node[LS(n)].mx += node[n].lazy; 56 node[LS(n)].lazy += node[n].lazy; 57 } 58 if(RS(n)) { 59 node[RS(n)].val += node[n].lazy; 60 node[RS(n)].mx += node[n].lazy; 61 node[RS(n)].lazy += node[n].lazy; 62 } 63 node[n].lazy = 0; 64 } 65 if(node[n].rev) { 66 if(LS(n)) node[LS(n)].rev ^= 1; 67 if(RS(n)) node[RS(n)].rev ^= 1; 68 swap(LS(n), RS(n)); 69 node[n].rev = 0; 70 } 71 } 72 73 void rotate(int n, bool kind) { 74 int fn = node[n].fa; 75 int ffn = node[fn].fa; 76 node[fn].ch[!kind] = node[n].ch[kind]; 77 node[node[n].ch[kind]].fa = fn; 78 79 node[n].ch[kind] = fn; 80 node[fn].fa = n; 81 82 node[ffn].ch[RS(ffn) == fn] = n; 83 node[n].fa = ffn; 84 pushup(fn); 85 } 86 87 //旋转到goal的儿子处 88 void splay(int n, int goal) { 89 pushdown(n); 90 while(node[n].fa != goal) { 91 int fn = node[n].fa; 92 int ffn = node[fn].fa; 93 pushdown(ffn); pushdown(fn); pushdown(n); 94 bool rotate_n = (LS(fn) == n); 95 bool rotate_fn = (LS(ffn) == fn); 96 if(ffn == goal) rotate(n, rotate_n); 97 else { 98 if(rotate_n == rotate_fn) rotate(fn, rotate_fn); 99 else rotate(n, rotate_n); 100 rotate(n, rotate_fn); 101 } 102 } 103 pushup(n); 104 if(goal == 0) root = n; 105 } 106 107 int select(int pos) { 108 int u = root; 109 pushdown(u); 110 while(node[LS(u)].size != pos) { 111 if(pos < node[LS(u)].size) 112 u = LS(u); 113 else { 114 pos -= node[LS(u)].size + 1; 115 u = RS(u); 116 } 117 pushdown(u); 118 } 119 return u; 120 } 121 122 int build(int L, int R) { 123 if(L > R) return 0; 124 if(L == R) return L; 125 int mid = (L + R) >> 1; 126 int r_L, r_R; 127 LS(mid) = r_L = build(L, mid - 1); 128 RS(mid) = r_R = build(mid + 1, R); 129 node[r_L].fa = node[r_R].fa = mid; 130 pushup(mid); 131 return mid; 132 } 133 134 void init(int n) { 135 node[0].init(-INF); node[0].size = 0; 136 node[1].init(-INF); 137 node[n + 2].init(-INF); 138 for(int i = 2; i <= n + 1; ++i) 139 node[i].init(0); 140 141 root = build(1, n + 2); 142 node[root].fa = 0; 143 144 node[0].fa = 0; 145 LS(0) = root; 146 } 147 148 void solve(int type,int l,int r,int val){ 149 int u = select(l-1), v = select(r+1); 150 splay(u,0);splay(v,u); 151 if(type == 1){ // Update 152 node[LS(v)].val += val; 153 node[LS(v)].mx += val; 154 node[LS(v)].lazy += val; 155 } 156 else if(type == 2) // Reverse 157 node[LS(v)].rev ^= 1; 158 else // Query 159 printf("%d\n",node[LS(v)].mx); 160 } 161 } spt; 162 163 int main() { 164 int n, m; 165 scanf("%d%d", &n, &m); 166 spt.init(n); 167 for(int i = 0; i < m; ++i) { 168 int op,l,r,v; 169 scanf("%d%d%d",&op,&l,&r); 170 if(op == 1) scanf("%d",&v); 171 spt.solve(op,l,r,v); 172 } 173 return 0; 174 }
BZOJ1503 - 郁闷的出纳员(插入、删除、查第k大):
思路:模板题,用第一个模板就好了。关键代码如下。
void add(int val){ int u = Find(-1e9), v = Find(1e9); Splay(u,0); Splay(v,u); Add[T[v].son[0]] += val; T[T[v].son[0]].key += val; } }spt; int n,m,cnt; int main(){ cin >> n >> m; spt.Insert(-1e9); spt.Insert(1e9); while(n--){ char op[3]; int k; scanf("%s%d",op,&k); if(op[0] == 'I'){ if(k >= m) spt.Insert(k); } if(op[0] == 'A') spt.add(k); if(op[0] == 'S'){ spt.add(-k); cnt += spt.GetRank(m-1) - 1; spt.Delete(-5e8,m-1); } if(op[0] == 'F'){ int t = spt.GetKth(k+1); if(t < -5e8) printf("-1\n"); else printf("%d\n",t); } } printf("%d\n",cnt); return 0; }
HDU3487 - Play with chain(区间翻转、切割):
题意:开始有一个1, 2, 3,... , n的序列,进行m次操作,CUT a b c将区间[a,b]取出得到新序列,将区间插入到新序列第c个元素之后,FLIP a b 将区间[a,b]翻转。输出最终的序列。
思路:对于CUT操作我们需要先提取出区间[a,b],然后删除,然后以c为边界分割为左右两部分,再合并c左边的区间和提取出来的区间[a,b],随后将最右的旋转至根(注意在第二个模板中,size是把初始化时的左右2个虚节点算进去了的),然后和c右边的区间合并。对于FLIP操作,我们可以进行lazy操作,先打个标记但不翻转,需要的时候再翻转。关键代码如下。
void solve(int type,int l,int r,int c){ int u = select(l-1), v = select(r+1); splay(u,0);splay(v,u); if(type == 1) /**< Reverse */ node[LS(v)].rev ^= 1; else{ /**< Cut */ int rt1 = LS(v); LS(v) = 0; // Delete [l,r] pushup(v); pushup(u); u = select(c); splay(u,0); int rt2 = RS(u); RS(u) = rt1; // Merge [1,c] with [l,r] node[rt1].fa = u; pushup(u); u = select(node[root].size-1); splay(u,0); RS(u) = rt2; // Merge node[rt2].fa = u; pushup(u); } } void traverse(int x){ if(!x) return; pushdown(x); traverse(LS(x)); ans.push_back(node[x].val); traverse(RS(x)); }