【BBST 之伸展树 (Splay Tree)】
最近“hiho一下”出了平衡树专题,这周的Splay一直出现RE,应该删除操作指针没处理好,还没找出原因。
不过其他操作运行正常,尝试用它写了一道之前用set做的平衡树的题http://codeforces.com/problemset/problem/675/D,运行效果居然还挺好的,时间快了大概10%,内存少了大概30%。
1 #include <cstdio> 2 #include <cstring> 3 #include <string> 4 #include <cstdlib> 5 #include <cctype> 6 #include <cmath> 7 #include <algorithm> 8 #include <vector> 9 #include <map> 10 #include <set> 11 #include <stack> 12 #include <queue> 13 #include <assert.h> 14 #define FREAD(fn) freopen((fn), "r", stdin) 15 #define RINT(vn) scanf("%d", &(vn)) 16 #define PINT(vb) printf("%d", vb) 17 #define RSTR(vn) scanf("%s", (vn)) 18 #define PSTR(vn) printf("%s", (vn)) 19 #define CLEAR(A, X) memset(A, X, sizeof(A)) 20 #define REP(N) for(i=0; i<(N); i++) 21 #define REPE(N) for(i=1; i<=(N); i++) 22 #define pb(X) push_back(X) 23 #define pn() printf("\n") 24 using namespace std; 25 const int MAX_N = 100005; 26 const int MAX_K = 0x7fffffff; 27 const int MIN_K = 0; 28 29 int a[MAX_N]; 30 int n; 31 int i; 32 map<int, bool> left, right;//iostream里有和left, right冲突的命名! 33 34 struct Node 35 { 36 int k; 37 Node *l, *r, *p; 38 Node():k(-1), l(NULL), r(NULL), p(NULL){} 39 Node(int kk, Node* pp):k(kk), l(NULL), r(NULL), p(pp){} 40 ~Node(){ 41 l = r = p = NULL; 42 } 43 }; 44 45 struct Splay 46 { 47 Node* root; 48 Node* _hot; 49 Splay():root(NULL), _hot(NULL){} 50 Splay(int k):root(new Node(k, NULL)), _hot(root){} 51 52 void release(Node* cur){//释放子树cur的空间 53 if(cur == NULL) return ;//空树 54 release(cur->l); 55 cur->l = NULL; 56 release(cur->r); 57 cur->r = NULL; //不用加吧,cur马上就销毁了啊 58 //printf("deleted %d\n", cur->k); 59 60 delete cur; 61 return ; 62 } 63 ~Splay(){ 64 release(root); 65 root = NULL; 66 } 67 void zig(Node* cur){ 68 if(cur == NULL) return ; 69 Node* v = cur->l; 70 if(v == NULL) return ; 71 Node* g = cur->p; 72 73 v->p = g; 74 if(g != NULL) 75 //祖先g与v连接 76 (cur == g->l) ? g->l = v : g->r = v; 77 78 //v与cur孩子过继 79 cur->l = v->r; 80 if(cur->l != NULL) cur->l->p = cur; 81 82 //v与cur角色转换 83 cur->p = v; 84 v->r = cur; 85 if(cur == root) root = v; 86 //printf("%d zigged\n", cur->k); 87 } 88 void zag(Node* cur){ 89 if(cur == NULL) return ; 90 Node* v = cur->r; 91 if(v == NULL) return ; 92 Node* g = cur->p; 93 //printf("g=%d cur=%d v=%d\n", g->k, cur->k, v->k); 94 95 v->p = g; 96 if(g != NULL) 97 (cur == g->l) ? g->l = v : g->r = v; 98 99 cur->r = v->l; 100 if(cur->r != NULL) cur->r->p = cur; 101 102 cur->p = v; 103 v->l = cur; 104 if(cur == root) root = v; 105 //printf("%d zagged\n", cur->k); 106 } 107 void splay(Node* x, Node* f){// make x become f's child 108 if(x == NULL) return ; 109 while(x->p != f){//逐步双层伸展 110 Node* p = x->p; 111 if(p == NULL) return ; 112 if(p->p == f) 113 (x == p->l) ? zig(p) : zag(p); 114 else{ 115 Node* g = p->p; 116 if(g == NULL) return ; 117 if(g->l == p){ 118 if(p->l == x){ 119 zig(g); zig(p); 120 }else{ 121 zag(p); zig(g); 122 } 123 }else{ 124 if(p->l == x){ 125 zig(p); zag(g); 126 }else{ 127 zag(g); zag(p); 128 } 129 } 130 } 131 } 132 } 133 Node* search(Node* cur, int k){//在cur子树中查找关键码k 134 if(cur == NULL) return _hot;//查找失败还伸展吗?暂不伸展,待决定插入后再将新插入的节点伸展 135 if(cur->k == k){//查找成功 136 //printf("has %d\n", cur->k); 137 splay(cur, NULL);//将目标节点伸展至根 138 return cur; 139 } 140 _hot = cur;//需要深入子树查找 141 return (k < cur->k) ? search(cur->l, k) : search(cur->r, k); 142 } 143 Node* insert(Node* cur, int k){//将关键码k插入cur子树 144 if(cur == NULL){//找到目标插入位置 145 cur = new Node(k, _hot); 146 //printf("%d %d\n", _hot->k, k); 147 (k < _hot->k) ? _hot->l=cur : _hot->r=cur; 148 _hot = cur; 149 //printf("create %d\n", cur->k); 150 splay(cur, NULL);//将目标节点伸展至树根 151 return cur; 152 } 153 assert(cur); 154 _hot = cur;//进入子树 155 //printf("enter %d\n", cur->k); 156 return (k < cur->k) ? insert(cur->l, k) : insert(cur->r, k);//assert:关键码互异 157 } 158 Node* prev(int k){//寻找关键码k的中序前驱 159 splay(search(root, k), NULL);//将k伸展至树根 160 Node* cur = root->l;//根节点的左子树 161 //assert(cur); 162 if(!cur) return NULL; 163 while(cur->r != NULL) cur = cur->r;//前驱必然为左子树的最右节点 164 return cur; 165 } 166 Node* succ(int k){//寻找关键码k的中序后继, assert:k一定存在 167 splay(search(root, k), NULL); 168 Node* cur = root->r; 169 //assert(cur); 170 if(!cur) return NULL; 171 while(cur->l != NULL) cur = cur->l; 172 return cur; 173 } 174 void deleteK(int k){//删除关键码k 175 Node* p = prev(k); 176 Node* s = succ(k); 177 splay(p, NULL); 178 splay(s, p); 179 Node* q = s->l; 180 s->l = NULL;//解除父子关系 181 release(q);//释放子树空间,这里只有一个节点k 182 } 183 void deleteInterval(int a, int b){//删除区间[a,b]内的关键码 184 Node* pa = search(root, a);//pa为最后一个被访问的节点,必不空 185 assert(pa); 186 if(pa->k != a) pa = insert(pa, a);//查找失败,插入 187 //printf("pa->k = a = %d\n", pa->k); 188 189 Node* pb = search(root, b); 190 assert(pb); 191 //printf("pb->k = b = %d\n", pb->k); 192 if(pb->k != b) pb = insert(pb, b); 193 194 Node* p = prev(a); 195 assert(p); 196 Node* s = succ(b);//assert: p, s not null 197 assert(s); 198 //printf("prev %d succ %d\n", p->k, s->k); 199 splay(p, NULL); 200 //printf("%d splayed\n", p->k); 201 splay(s, p); 202 //printf("%d splayed\n", s->k); 203 Node* q = s->l; 204 _hot = s; 205 release(q);//释放子树空间 206 s->l = NULL; 207 } 208 }; 209 210 int main() 211 { 212 FREAD("675d.txt"); 213 RINT(n); 214 REP(n) RINT(a[i]); 215 Splay mySplay(a[0]); 216 for(i=1; i<n; i++){ 217 // Node* p = mySplay.search(mySplay.root, a[i]);//必然失败 218 // if(p->k > a[i]) 219 // printf("%d\n", mySplay.prev(p->k)); 220 // else printf("%d\n", p->k); 221 int ans = 0; 222 Node* q = mySplay.insert(mySplay.root, a[i]); 223 Node* p = mySplay.prev(a[i]); 224 if(p && right.count(p->k) == 0){ 225 right[p->k] = 1;//前驱没有右孩子 226 ans = p->k; 227 }else{ 228 Node* s = mySplay.succ(a[i]); 229 if(s && left.count(s->k) == 0){ 230 left[s->k] = 1; 231 ans = s->k; 232 } 233 } 234 printf("%d\n", ans); 235 } 236 return 0; 237 }
再次做这道题,我对BST的认识更清晰了一些,在此梳理一下:
1. 首先,我们把中序遍历序列相同的二叉搜索树互称“等价BST”;可以看出,对于一个中序遍历序列,可以画出若干棵拓扑结构(即祖先后代的关系)不同的等价BST,且它们可以通过一些列“等价变换”而互相转换。常用的“等价变换”方法有我们熟悉的“旋转调整”,如下图(引自数据结构的课件):
可以这样来记忆:zig是顺时针(clockwise)方向,zag是逆时针(anti-clockwise)方向;
而zig(p)或zag(p)可以形象地看成是“把p压下来,把它的孩子翘上去”。(注:hihocoder的教程和我的记法不太一样,习惯一种就好,不要搞混)
这样的zig/zag局部拓扑结构调整,在实现步骤上可以分如下三步走,代码上面有。
(1)v与p的祖先g建立连接 --->(2)v的孩子Y过继给p ---> (3)v与p角色互换
2. 平衡二叉搜索树在进行旋转调整时,树的拓扑结构发生了改变,而且这种改变如果不额外记录信息的话,是没办法直接从结果拓扑反推原始拓扑的。因为每一步调整都有若干种可能,纵使现成的红黑树set对外提供了父节点、孩子节点的接口,所返回的也是变换后的当前拓扑结构,无法直接得出原始拓扑。
3. 此题求的是原始拓扑结构中每次所插入的节点的父节点,数据范围10^5,不能承受退化情况的复杂度,故不可直接模拟,必须用平衡树来维护真实数据。而拓扑结构这一性质随着我们的旋转调整而发生了改变,所以必须想办法把原始拓扑结构记下来。而具体需要记录什么呢?
(1)经过上次博客的分析,我们发现节点v的父节点必然是其中序遍历的直接前驱或直接后继。我们先假设前驱和后继这个信息可以方便地得到,那么如何判定究竟是前驱还是后继呢?这个便是问题的关键所在了。由上次博客的结论,在原始拓扑中,若v作为前驱p的右孩子插入,则插入前p的右孩子必为空;与之对称的是s的左孩子为空的情况。那么我们只需记录原始拓扑中每个节点是否有左孩子、右孩子这一信息即可。因此用两个数组即可。不过这道题节点数值范围为10^9开不下,所以用了map。这个做法来自题目作者的题解。
(2)再来考虑如何确定v的前驱后继:无论怎么调整,等价BST的中序遍历序列是不变的,故节点v任意时刻的直接前驱和直接后继也不会改变,所以对前驱和后继的信息我们无需额外维护,而可以在任意时刻根据当前拓扑在O(logn)时间内求得。
说了这么多,一直在分析这道题,还没说到Splay。。。
Splay Tree(伸展树)是BBST家族中一种很“潇洒”的数据结构,实际上,它在很多情况下整体上并不处于“平衡状态”,即不能保证对所有节点的访问控制在O(logn)。但是它的设计很有现实意义,(最近复习计算机系统结构,发现它的设计理念十分符合“以经常性时间为重点”和“局部性原理”这两条系统结构设计的定量原理)。下面来具体看下它的设计思路:
1. 对于传统的平衡搜索树如AVL树,它假定每次对每个节点的访问是等概率的,所以每次动态操作后都“小心翼翼”地维护着平衡因子,从而使得最深的节点的访问成本也能控制在O(logn)。而现实情况中对数据的访问通常并不是等概率的,相反,它常具有局部性;在关注吞吐率而不是单次操作的场合,这一问题更为突出。
2. 这里我们转而关注连续访问一批数据的总体时间。要想总体时间少,我们采用贪心策略,让越经常被访问的节点访问成本越小,这一点和哈夫曼编码的设计如出一辙;然而此处的每个节点的“经常性”是动态变化的,可以有一部分历史数据作为“经验”,但通常要把每次访问的情况吸收到经验中作为下次调整的参考。
3. 由此得出Splay Tree的设计思路:按照“经常性”动态调整拓扑结构,一种实现方法就是:每次把刚刚访问过的节点调整到树根。
另,用教科书上的话说,这是一种“即用即调整的启发式策略”,是“自调整链表”的一种推广。(记得严蔚敏版《数据结构题集》线性表一章出现过“自调整链表”,当时我还傻傻地叫它“频度伸展链表”(https://github.com/helenawang/ywmDS/blob/master/LinkList/FreqList.c))
具体调整的实现,即伸展操作:和AVL树一样,伸展树的调整也分4种情况,而且其中“之字形”的两种的调整与AVL的双旋完全一致;
不同在于如下左图(再次盗用了数据结构的课件)这种“三代同侧”的情况:
(1)如左图上部,AVL通常只需做一次zig(p)的单旋便达到了局部的平衡;若想把v调成最高代,还要再做一次zig(g)的单旋;
(2)如左图下部,伸展树先做了一次zig(g),再做一次zig(p),把三代从“一边倒”调成了“另一边倒”;
(3)这两种做法的区别可从右图直观感受到,双层调整可以“折叠”沿途节点,从而降低树高。
注意伸展操作不只发生在动态操作后,每次查找操作也要进行伸展。对于删除操作,"hiho一下"的教程给出的做法很巧妙,即:找到待删除节点的前驱p和后继s,然后先把p伸展至树根,再把s伸展至p的右子树,至此,待删除节点(区间删除也可以)必然位于s的左子树,把左子树摘除并释放空间即可。代码如上,但是我提交后出现RE,原因尚不明,可能是释放空间后指针没置空。
hiho一下第104周 用java写的版本,可以AC
1 import java.util.*; 2 import java.util.Scanner; 3 4 public class Main{ 5 static int MIN_K = 0; 6 static int MAX_K = 1000000005; 7 public static void main(String[] args) { 8 Splay mySplay = new Splay(MIN_K); 9 mySplay.insert(mySplay.root, MAX_K); 10 11 Scanner in = new Scanner(System.in); 12 int n = in.nextInt(); 13 String c; 14 int k, a, b; 15 for(int i=0; i<n; i++){ 16 c = in.next(); 17 switch(c){ 18 case "I": 19 k = in.nextInt(); 20 mySplay.insert(mySplay.root, k); 21 break; 22 case "Q": 23 k = in.nextInt(); 24 Node t = mySplay.search(mySplay.root, k); 25 if(k < t.k) t = mySplay.prev(k); 26 System.out.println(t.k); 27 break; 28 case "D": 29 a = in.nextInt(); b = in.nextInt(); 30 mySplay.deleteInterval(a, b); 31 break; 32 default: break; 33 } 34 } 35 } 36 } 37 38 class Splay { 39 Node root; 40 Node _hot; 41 Splay(){ 42 root = _hot = null; 43 } 44 Splay(int k){ 45 root = new Node(k, null); 46 _hot = root; 47 } 48 void zig(Node cur){ 49 if(cur == null) return ; 50 Node v = cur.l; 51 if(v == null) return ; 52 Node g = cur.p; 53 54 v.p = g; 55 if(g != null){ 56 if(cur == g.l) g.l = v; 57 else g.r = v; 58 } 59 60 cur.l = v.r; 61 if(cur.l != null) cur.l.p = cur; 62 63 cur.p = v; 64 v.r = cur; 65 if(cur == root) root = v; 66 } 67 void zag(Node cur){ 68 if(cur == null) return ; 69 Node v = cur.r; 70 if(v == null) return ; 71 Node g = cur.p; 72 73 v.p = g; 74 if(g != null){ 75 if(cur == g.l) g.l = v; 76 else g.r = v; 77 } 78 79 cur.r = v.l; 80 if(cur.r != null) cur.r.p = cur; 81 82 cur.p = v; 83 v.l = cur; 84 if(cur == root) root = v; 85 } 86 void splay(Node x, Node f){ 87 if(x == null) return ; 88 while(x.p != f){ 89 Node p = x.p; 90 if(p == null) return ; 91 if(p.p == f){ 92 if(x == p.l){ 93 94 zig(p); 95 } 96 else{ 97 zag(p); 98 } 99 }else{ 100 Node g = p.p; 101 if(g == null) return ; 102 if(g.l == p){ 103 if(p.l == x){ 104 zig(g); zig(p); 105 }else{ 106 zag(p); zig(g); 107 } 108 }else{ 109 if(p.l == x){ 110 zig(p); zag(g); 111 }else{ 112 zag(g); zag(p); 113 } 114 } 115 } 116 } 117 } 118 Node search(Node cur, int k){ 119 if(cur == null) return _hot; 120 if(cur.k == k){ 121 splay(cur, null); 122 return cur; 123 } 124 _hot = cur; 125 if(k < cur.k) return search(cur.l, k); 126 else return search(cur.r, k); 127 } 128 Node insert(Node cur, int k){ 129 if(cur == null){ 130 cur = new Node(k, _hot); 131 if(k < _hot.k) _hot.l = cur; 132 else _hot.r = cur; 133 _hot = cur; 134 //System.out.println("find place"); 135 splay(cur, null); 136 //System.out.println(cur.k + "created"); 137 return cur; 138 } 139 _hot = cur; 140 if(k < cur.k) return insert(cur.l, k); 141 else return insert(cur.r, k); 142 } 143 Node prev(int k){ 144 splay(search(root, k), null); 145 Node cur = root.l; 146 if(cur == null) return null; 147 while(cur.r != null) cur = cur.r; 148 return cur; 149 } 150 Node succ(int k){ 151 splay(search(root, k), null); 152 Node cur = root.r; 153 if(cur == null) return null; 154 while(cur.l != null) cur = cur.l; 155 return cur; 156 } 157 void deleteInterval(int a, int b){ 158 Node pa = search(root, a); 159 if(pa.k != a) pa = insert(pa, a); 160 Node pb = search(root, b); 161 if(pb.k != b) pb = insert(pb, b); 162 163 Node p = prev(a); 164 Node s = succ(b); 165 splay(p, null); 166 splay(s, p); 167 Node q = s.l; 168 _hot = s; 169 s.l = null; 170 } 171 } 172 class Node{ 173 int k; 174 Node l, r; 175 Node p; 176 Node(){p = null;} 177 Node(int kk, Node pp){ 178 k = kk; 179 p = pp; 180 l = r = null; 181 } 182 }