伸展树的基本操作——以【NOI2004】郁闷的出纳员为例
前两天老师讲了伸展树……虽然一个月以前自己就一直在看平衡树这一部分的书籍,也仔细地研读过伸展树地操作代码,但是就是没写过程序……(大概也是在平衡树的复杂操作和长代码面前望而生畏了)但是今天借着老师布置作业这个机会,加上hockey之前不厌其烦地手把手带着我写过四五遍Splay的神代码,是时候把它用程序实现出来了。
第一部分 伸展树基本概念和操作
伸展树是一种平衡二叉查找树,但是和其它平衡树(比如红黑树、AVL树)不同,它的节点上没有记录任何用于保持平衡的其他信息(AVL树上记录了节点的高度,而红黑树上记录了节点的颜色)。而更奇葩的地方在于,作为一棵平衡树,它在实际中可能并不平衡。当树中的节点数为N的时候,一次插入/查找操作的最坏复杂度是O(N),AVL树和红黑树试图通过修整树的结构(具体来说是高度)来避免这种最会情况的发生。而伸展树的思想并非如此——假如最坏情况的时间开销是O(N),那么我们不用尝试彻底避免它,而是只要它不要经常发生就好了。为了达到这一点,我们在每次访问一个结点之后将它通过一系列平衡树的旋转操作转到树根的位置。至于这么做的好处,我们可以通过概率和摊还分析中的势能方法证明进行M次操作的摊还时间界为O(MlogN),这个复杂度的证明过程十分地复杂,在这里不再赘述(我总是用这种方式跳过我不会的东西……学懂了之后一定补上)
在正式介绍各个操作之前先明确一下我们的节点定义形式(这些看似奇怪的定义方法在简化代码提高效率方面特别有用)
1 struct node 2 { 3 node *p, *s[2]; 4 int key, size; 5 node(){p = s[0] = s[1] = 0; size = 1; key = 0;} 6 node(int key) :key(key) {p = s[0] = s[1] = 0; size = 1;} 7 bool getlr(){return p->s[1] == this;} 8 node *link(int w, node *p){s[w] = p; if(p)p->p = this; return this;} 9 void update(){size = 1 + (s[0] ? s[0]->size : 0) + (s[1] ? s[1]->size : 0);} 10 } ;
node为我们的节点类,p和s[0]、s[1]均为指向node的指针类型,其中p表示当前节点的父节点,s[0]表示左儿子,s[1]表示右儿子
key表示当前节点的键值,size表示当前节点为根的子树的节点个数
接下来是两个构造函数,分别是无参数初始化和使用键值初始化,想必不用过多解释
getlr()返回一个布尔值,表示当前节点是它的父亲的哪个儿子(0为左,1为右)
link(int w, node *p)表示将节点p连接到当前节点的w孩子位置上(照例,0为左,1为右,下文中不再说明),并返回当前节点(hockey:一会你就能知道这个东西多么神)
update()表示更新当前节点的size,BTW,不熟悉C++ 中“?:”运算符的同学需要复习一下,因为接下来会多次使用
首先先复习一下平衡树的最基本操作——旋转
旋转操作这种东西最直观最容易理解的就是看图啦,接下来我们要用自然语言描述这个过程以便准确的写出程序
1.左旋:当p是它父亲的右儿子时执行左旋,将它的左儿子变为p的父亲的新的右儿子,而p原来的父亲变为p新的左儿子
2.右旋:当p是它父亲的左儿子时执行右旋,将它的右儿子变为p的父亲的新的左儿子,而p原来的父亲变为p新的右儿子
通俗一点来讲,以上图中的右旋为例,操作实际上就是将左图中的BE断开,在AE间链接一条边,然后用手捏着B让A自由下落,就变成了右图的情况,反之即是左旋。
另外很重要的一点,既然p的旋转方式完全取决于p->getlr()的值,那么我们完全没有必要把左旋和右旋分开成两个程序,下面上代码~
1 void rot(node *p) 2 { 3 node *q = p->p->p; 4 p->getlr() ? p->link(0, p->p->link(1, p->s[0])) : p->link(1, p->p->link(0, p->s[1])); 5 p->p->update(); 6 if(q)q->link(q->s[1] == p->p, p); else{p->p = 0; root = p;} 7 }
接下来开始解释代码——
首先先找到p的爷爷保存起来。
接下来一行就是重点!请牢记旋转操作的定义兵仔细读这行代码:如果p是右儿子,就将p的左儿子连接到p的父亲上,再将连接好的p的父亲连接到p的左儿子位置上,反之则是右旋。这一行充分体现了link()函数的优越性。
然后更新p的父亲(这时的p的父亲还是原来那个父亲,想一想为什么)
最后把p连到它父亲原来的位置上
至于为什么不需要更新p,答案是我们会在splay操作中一直调用rot把p转到树根,那时候再更新它也不迟
3.伸展:
顾名思义,伸展树的核心操作自然应该是伸展(splay)。splay(p)的定义很简单,为将p转到根的位置上,我们需要细致的考虑一下它的实现方式。
最简单的,也是最容易想到的方式就是不停地对p做旋转直至p被转到根上,但是这样会有一个问题——在某种情况下均摊时间界会被破坏,而得到这样的一组输入是轻而易举的[1]
下图很好地说明了这个例子(Ubuntu下画图不好用大家见谅):如果我们不停地旋转做左图中的1把它旋转到根的话,结果就会变成右图那样(在黑板上画一画)
显然如果我们继续访问2号节点,整棵树仍旧会像一条链,每次访问的复杂度都是O(N),这样我们的平衡树就没有意义了。
但是我们有没有什么方法来避免它?答案是肯定的。接下来我们将引入双旋的概念(将前文中的左旋右旋统称为单旋)
(1)一字形旋转:
这种旋转方式就是为了避免上文中那种情况的出现而创造的。它的自然语言说明如下:如果p和p的父亲同是它们的父亲的左儿子或者右儿子时,先对p的父亲进行一次单旋,再对p进行单旋,如下图所示
我们现在想要把7号节点转到根上,首先对6进行一次左旋
然后我们再对7做一次左旋,就变成了下图那样
从直观角度看来,这种旋转方式并没有任何优化效果——因为它把一条链变成了反过来的另外一条链,但是实则不然。在图中我们用一个等腰三角形来代表一棵子树,但是子树的尺寸并不一致——大家可以通过手推一下上文中8个节点的例子使用一字形双旋转的出色效果,这里不再赘述(图太难画了……)
(2)之字形旋转:
这是另外一种双旋转,当p和p的父亲不同为左儿子或者右儿子的时候,直接对p进行两次单旋就好了
综上我们可以发现双旋转的特殊情况仅有一字形那一种,接下来Splay核心操作代码奉上~
1 void splay(node *p) 2 { 3 while(p->p && p->p->p) 4 p->getlr() == p->p->getlr() ? (rot(p->p), rot(p)) : (rot(p), rot(p)); 5 if(p->p)rot(p); 6 p->update(); 7 }
这里定义splay(node *p)意为把p旋转到根。while循环中是两种双旋转,如果没有通过双旋转转到根,则最后进行一次单旋转。由于在rot操作中我们没有更新p,在最后我们把它更新一下
接下来的操作都很基础很简单啦:大部分和一般的二叉查找树没有区别
4.插入:
直接像二叉查找树一样插入就好啦,记得在最后把插入的节点splay到根上
1 void insert(int x) 2 { 3 if(!root){root = new node(x); return;} 4 node *p = root, *p1; 5 while(p){p1 = p; p = p->s[p->key < x];} 6 p = new node(x); 7 p1->link(p1->key < x, p); 8 splay(p); 9 }
5.查找:
还是像二叉查找树一样查找,最后不要忘记splay到根上,然后返回找到的那个节点的指针,找不到则返回空指针
1 node *find(int x) 2 { 3 node *p = root; 4 while(p && p->key != x)p = p->s[p->key < x]; 5 if(p)splay(p); 6 return p; 7 }
6.寻找第k小:
这个难度不大,稍微需要一点分析——如果p的左子树的size+1恰好等于k,则p为所求,若左子树的size大于等于k,则在左子树中继续查找。若左子树的size+1小于k,则在右子树中继续查找,同时别忘了把k减去左子树size+1。最后返回第k小的指针,找不到则返回空指针。特别注意左子树不存在的情况
1 node *findKth(int k) 2 { 3 if(root->size < k)return 0; 4 node *p = root; 5 while(!(((p->s[0] ? p->s[0]->size : 0) < k) && ((p->s[0] ? p->s[0]->size : 0) + 1 >= k))) 6 if(!p->s[0]){k -= 1; p = p->s[1];} 7 else {if(p->s[0]->size >= k)p = p->s[0];else{k = k - p->s[0]->size - 1;p = p->s[1];}} 8 if(p)splay(p); 9 return p; 10 }
7.前驱:
prev()操作求的是比当前根小的最大的数的指针,没什么难度,主要在删除操作时会用到,别忘记splay到根
1 node *prev() 2 { 3 node *p = root->s[0]; 4 if(!p)return 0; 5 while(p->s[1])p = p->s[1]; 6 splay(p); 7 return p; 8 }
8.后继:
和前驱操作对应的,求比当前根大的最小的数,也是在删除操作会用到
1 node *succ() 2 { 3 node *p = root->s[1]; 4 if(!p)return 0; 5 while(p->s[0])p = p->s[0]; 6 splay(p); 7 return p; 8 }
9.splay2:
同样是为删除操作做的准备工作,和前面的splay唯一的区别在于将p旋转到某一个顶点的儿子位置而非根节点的位置上
1 void splay(node *p, node *tar) 2 { 3 while(p->p != tar && p->p->p != tar) 4 p->getlr() == p->p->getlr() ? (rot(p->p), rot(p)) : (rot(p), rot(p)); 5 if(p->p != tar)rot(p); 6 p->update(); 7 }
我们在调用的时候会保证tar是p的某一个祖先(一般情况下tar是根节点)。特别地,我们可以用splay(p, 0)来代替splay(p)
10.删除:
在前面做了那么多准备工作之后,终于开始进行删除啦。伸展树中只支持段删除(亦即删除区间[l,r]内的左右节点)。我们先找到l的前驱p和r的前驱q(此时保证q为根),然后将p splay到q的左儿子处,这是p的右儿子就是所有满足区间[l,r]的节点,直接删除即可。如果原树中没有l和r,我们直接把它insert进去就行(反正最后会被删掉)。记得特别处理前驱后继不存在的情况,最后需要进行update
1 void del(int l, int r) 2 { 3 if(!find(l))insert(l); 4 node *p = prev(); 5 if(!find(r))insert(r); 6 node *q = succ(); 7 if(!p && !q){root = 0; return;} 8 if(!p){root->s[0] = 0; root->update(); return;} 9 if(!q){splay(p, 0); root->s[1] = 0; root->update(); return;} 10 splay(p, q); 11 p->s[1] = 0; 12 p->update(); 13 q->update(); 14 }
小结:
到此为止,伸展树的全部基础操作已经讲解完了。我们可以在伸展树上维护很多其他的信息来达到某些效果(比如起到线段树的作用)。数据结构是固定的,但它的思想是灵活的。在实际应用中不应该拘泥于模板,而是应该大胆创新,突破现有的束缚
第二部分 应用:【NOI2004】郁闷的出纳员分析
题目大意:实现一个数据结构满足在序列中插入一个数k、将所有数加上k、将所有数减去k并删除所有小于min的数、查找第K大的数这四种操作。满足操作数m<=100000,所有的数<=200000
很显然这道题可以用一个裸的伸展树来实现,但是需要做一些小小的调整:
1.在每个节点上增加一个值num,表示当前节点重复出现的个数(因为二叉查找树默认为两两顶点之间是互异的,对于重复的数我们只能把它记在同一个顶点里),对应需要修改update()、findKth()两个操作
2.每次加减都是对所有数的操作,如果我们直接模拟这个操作一定会超时,应该开一个变量delta,表示所有数的变化量
3.题目中有一个地方没有描述清:如果一个人来了就走,不计在走的人数里
1 //date 20131201 2 #include <cstdio> 3 #include <cstring> 4 5 #define INF 1000000 6 7 int ans; 8 9 struct Splay 10 { 11 struct node 12 { 13 node *p, *s[2]; 14 int key, size, num; 15 node(){p = s[0] = s[1] = 0; size = num = 1; key = 0;} 16 node(int key) :key(key) {p = s[0] = s[1] = 0; size = num = 1;} 17 bool getlr(){return p->s[1] == this;} 18 node *link(int w, node *p){s[w] = p; if(p)p->p = this; return this;} 19 void update(){size = num + (s[0] ? s[0]->size : 0) + (s[1] ? s[1]->size : 0);} 20 } *root; 21 void rot(node *p) 22 { 23 node *q = p->p->p; 24 p->getlr() ? p->link(0, p->p->link(1, p->s[0])) : p->link(1, p->p->link(0, p->s[1])); 25 p->p->update(); 26 if(q)q->link(q->s[1] == p->p, p); else{p->p = 0; root = p;} 27 } 28 void splay(node *p, node *tar) 29 { 30 while(p->p != tar && p->p->p != tar) 31 p->getlr() == p->p->getlr() ? (rot(p->p), rot(p)) : (rot(p), rot(p)); 32 if(p->p != tar)rot(p); 33 p->update(); 34 } 35 void preset(){root = 0;} 36 node *find(int x) 37 { 38 node *p = root; 39 while(p && p->key != x)p = p->s[p->key < x]; 40 if(p)splay(p, 0); 41 return p; 42 } 43 void insert(int x) 44 { 45 if(!root){root = new node(x); return;} 46 if(find(x)){++root->num; root->update(); return; } 47 node *p = root, *p1; 48 while(p){p1 = p; p = p->s[p->key < x];} 49 p = new node(x); 50 p1->link(p1->key < x, p); 51 splay(p, 0); 52 } 53 node *findKth(int k) 54 { 55 if(root->size < k)return 0; 56 node *p = root; 57 while(!(((p->s[0] ? p->s[0]->size : 0) < k) && ((p->s[0] ? p->s[0]->size : 0) + p->num >= k))) 58 if(!p->s[0]){k -= p->num; p = p->s[1];} 59 else {if(p->s[0]->size >= k)p = p->s[0];else{k = k - p->s[0]->size - p->num;p = p->s[1];}} 60 if(p)splay(p, 0); 61 return p; 62 } 63 node *prev() 64 { 65 node *p = root->s[0]; 66 if(!p)return 0; 67 while(p->s[1])p = p->s[1]; 68 splay(p, 0); 69 return p; 70 } 71 node *succ() 72 { 73 node *p = root->s[1]; 74 if(!p)return 0; 75 while(p->s[0])p = p->s[0]; 76 splay(p, 0); 77 return p; 78 } 79 void del(int l, int r) 80 { 81 if(!find(l)){insert(l);--ans;} 82 node *p = prev(); 83 if(!find(r)){insert(r);--ans;} 84 node *q = succ(); 85 if(!p && !q){ans += root->size; preset(); return;} 86 if(!p){ans += root->s[0] ? root->s[0]->size : 0; root->s[0] = 0; root->update(); return;} 87 if(!q){splay(p, 0); ans += root->s[1] ? root->s[1]->size : 0; root->s[1] = 0; root->update(); return;} 88 splay(p, q); 89 if(p->s[1])ans += p->s[1]->size; 90 p->s[1] = 0; 91 p->update(); 92 q->update(); 93 } 94 }S; 95 96 int n, m; 97 char sign; int x; 98 99 int main() 100 { 101 scanf("%d%d\n", &n, &m); 102 int delta = 0; 103 S.preset(); 104 ans = 0; 105 for(int i = 1; i <= n; ++i) 106 { 107 scanf("%c %d\n", &sign, &x); 108 switch(sign) 109 { 110 case 'I': if(x >= m)S.insert(x - delta);else ++ans; break; 111 case 'A': delta += x; break; 112 case 'S': delta -= x; S.del(-INF, m - delta - 1); break; 113 case 'F': if(!S.root || x > S.root->size)printf("-1\n");else{S.findKth(S.root->size + 1 - x); printf("%d\n", S.root->key + delta);} 114 } 115 } 116 printf("%d\n", ans); 117 return 0; 118 }
参考文献:
[1]数据结构与算法分析(C++描述 第三版),【美】Mark Allen Weiss, 张怀勇等 译,人民邮电出版社,2007
[2]算法导论(第二版),【美】Thomas H. Cormen , Charles E. Leiserson, Ronald L. Rivest, Clifford Stein, 潘金贵等 译,机械工业出版社,2011
[3]ACM国际大学生程序设计竞赛:知识与入门,俞勇,清华大学出版社,2012