导航

红黑树的Python实现

Posted on 2014-10-06 15:27  Rohan  阅读(1115)  评论(0)  编辑  收藏  举报

想用红黑树,怎么搜都搜不到现成的Python实现。干脆自己写一个。

算法的结构按照Sedgewick的《算法(4th)》一书第三章写成,略有改动。

完整的API实现,也用了一些测试case,暂时没发现问题。

这玩意就是好用,谁用谁知道。

废话不多说直接上代码。

#注意事项: 重载 RBT.Node.reduce(self, new_val) 可自定义 append() 在主键已存在时对新旧值的合并操作。默认实现调用 list.extend(),因此默认情况下值应为列表。

  1 #!/usr/bin/env python
  2 #coding: gbk
  3 
  4 ########################################################################
  5 #Author: Feng Ruohang
  6 #Create: 2014/10/06 11:38
  7 #Digest: Provide a common data struct: Red Black Tree
  8 ########################################################################
  9 
 10 class RBT(object):
 11     class Node(object):
 12         '''
 13         Node used in RBTree
 14         '''
 15         def __init__(self,key,value=None,color=False,N=1):
 16             self.key = key
 17             self.val = value
 18             self.color = color   #False for Black, True for Red
 19             self.N = N           #Total numbers of nodes in this subtree
 20             self.left = None
 21             self.right = None
 22 
 23         def __cmp__(l,r):
 24             return cmp(l.key,r.key)
 25 
 26         def __eq__(l,r):
 27             return True if l.key == r.key else False
 28 
 29         def __add__(l,r):
 30             l.value + r.value
 31 
 32         def reduce(self,new_val):
 33             self.val.extend(new_val)
 34 
 35     def __init__(self):
 36         self.root = None
 37 
 38     #====================APIs====================#
 39     #=====Basic API
 40 
 41     def get(self,key):
 42         return  self.__get(self.root,key)
 43 
 44     def put(self,key,val):
 45         self.root = self.__put(self.root,key,val)
 46         self.root.color = False 
 47 
 48     def append(self,key,val):
 49         self.root = self.__append(self.root,key,val)
 50         self.root.color = False
 51 
 52     def delete(self,key):
 53         if not self.contains(key):
 54             raise LookupError('No such keys in rbtree. Fail to Delete')
 55         if self.__is_black(self.root.left) and self.__is_black(self.root.right):
 56             self.root.color = True
 57         self.root = self.__delete(self.root,key)
 58         if not self.is_empty(): 
 59             self.root.color = False
 60 
 61     def del_min(self):
 62         if self.is_empty(): 
 63             raise LookupError('Empty Red-Black Tree. Can\'t delete min')
 64         if self.__is_black(self.root.left) and self.__is_black(self.root.right):
 65             self.root.color = True
 66         self.root = self.__del_min(self.root)
 67         if not self.is_empty(): self.root.color = False
 68 
 69     def del_max(self):
 70         if self.is_empty():
 71             raise LookupError('Empty Red-Black Tree. Can\'t delete max')
 72         if self.__is_black(self.root.left) and self.__is_black(self.root.right):
 73             self.root.color = True
 74         self.root = self.__del_max(self.root)
 75         if not self.is_empty(): self.root.color = False 
 76 
 77     def size(self):
 78         return self.__size(self.root)
 79 
 80     def is_empty(self):
 81         return not self.root
 82 
 83     def contains(self,key):
 84         return bool(self.get(key))
 85 
 86     #=====Advance API
 87     def min(self):
 88         if self.is_empty():
 89             return None
 90         return self.__min(self.root).key
 91 
 92     def max(self):
 93         if self.is_empty():
 94             return None
 95         return self.__max(self.root).key
 96 
 97     def floor(self,key):
 98         x = self.__floor(self.root,key)
 99         if x:
100             return x.key,x.val
101         else:
102             return None,None
103 
104     def ceil(self,key):
105         x = self.__ceil(self.root,key)
106         if x:
107             return x.key,x.val
108         else:
109             return None,None
110 
111     def below(self,key):
112         index = self.index(key)
113         if not 0 <= index - 1 < self.size():
114             return None,None    #Return None if out of range
115         x = self.__select(self.root,index - 1)
116         return x.key,x.val
117 
118     def above(self,key):
119         index = self.index(key)
120         if self.contains(key):
121             if not 0 <= index + 1 < self.size():
122                 return None,None    #Return None if out of range
123             else:
124                 x = self.__select(self.root,index+1)
125                 return x.key,x.val
126         else:#if key is not in tree. then select(i) is what we need
127             if not 0 <= index < self.size():
128                 return None,None    #Return None if out of range
129             else:
130                 x = self.__select(self.root,index)
131                 return x.key,x.val
132 
133     def index(self,key):
134         return self.__index(self.root,key) 
135 
136     def keys(self):
137         '''Return All Keys in the tree '''
138         return self.range(self.min(),self.max())
139 
140     def range(self,lo,hi):
141         '''Take two keys. return keys between them'''
142         q = []
143         self.__range(self.root,q,lo,hi)
144         return q
145 
146     def select(self,index):
147         '''Given Index Return Corresponding key '''
148         if not 0 <= index < self.size():
149             return None
150         return self.__select(self.root,index).key
151 
152     def width(self,lo,hi):
153         '''Return the numbers of keys between lo and hi '''
154         if lo > hi: 
155             return 0
156         if self.contains(hi):
157             return self.index(hi) - self.index(lo) + 1
158         else:
159             return self.index(hi) - self.index(lo)
160 
161 
162     #===============Private Method===============#
163     #=====Basic
164     def __get(self,x,key):
165         while x:
166             tag = cmp(key,x.key)
167             if tag < 0 : x = x.left
168             elif tag > 0 :x = x.right
169             else: return x.val
170 
171     def __put(self,h,key,val):
172         if not h: 
173             return self.Node(key,val,True,1)
174         tag = cmp(key,h.key)
175         if tag < 0: 
176             h.left = self.__put(h.left,key,val)
177         elif tag > 0: 
178             h.right = self.__put(h.right,key,val)
179         else: 
180             h.val = val   #Update
181 
182         if self.__is_black(h.left) and self.__is_red(h.right):
183             h = self.__rotate_left(h)
184         if self.__is_red(h.left) and self.__is_red(h.left.left):
185             h = self.__rotate_right(h)
186         if self.__is_red(h.left) and self.__is_red(h.right):
187             self.__flip_colors(h)
188         h.N = self.__size(h.left) + self.__size(h.right) + 1
189         return h
190 
191     def __append(self,h,key,val):
192         if not h: 
193             return self.Node(key,val,True,1)
194         tag = cmp(key,h.key)
195         if tag < 0: 
196             h.left = self.__append(h.left,key,val)
197         elif tag > 0: 
198             h.right = self.__append(h.right,key,val)
199         else: 
200             h.reduce(val)   #append.
201 
202         if self.__is_black(h.left) and self.__is_red(h.right):
203             h = self.__rotate_left(h)
204         if self.__is_red(h.left) and self.__is_red(h.left.left):
205             h = self.__rotate_right(h)
206         if self.__is_red(h.left) and self.__is_red(h.right):
207             self.__flip_colors(h)
208         h.N = self.__size(h.left) + self.__size(h.right) + 1
209         return h
210 
211     def __del_min(self,h):
212         if not h.left: #if h is empty:return None
213             return None
214 
215         if self.__is_black(h.left) and self.__is_black(h.left.left):
216             self.__move_red_left(h)
217         h.left = self.__del_min(h.left) #Del recursive
218         return self.__balance(h)
219 
220     def __del_max(self,h):
221         if self.__is_red(h.left): 
222             h = self.__rotate_right(h)
223         if not h.right: 
224             return None
225         if self.__is_black(h.right) and self.__is_black(h.right.left):
226             h = self.__move_red_right(h)
227         h.right = self.__del_max(h.right)
228         return self.__balance(h)
229 
230     def __delete(self,h,key):
231         if key < h.key:
232             if self.__is_black(h.left) and self.__is_black(h.left.left):
233                 h = self.__move_red_left(h)
234             h.left = self.__delete(h.left,key)
235         else:
236             if self.__is_red(h.left):
237                 h = self.__rotate_right(h)
238             if key == h.key and not h.right:
239                 return None
240             if self.__is_black(h.right) and self.__is_black(h.right.left):
241                 h = self.__move_red_right(h)
242             if key == h.key:#replace h with min of right subtree
243                 x = self.__min(h.right)
244                 h.key = x.key
245                 h.val = x.val
246                 h.right = self.__del_min(h.right)
247             else:
248                 h.right = self.__delete(h.right,key)
249         h = self.__balance(h)
250         return h
251     
252     #=====Advance
253     def __min(self,h):
254         #Assume h is not null
255         if not h.left:
256             return h
257         else:
258             return self.__min(h.left)
259 
260     def __max(self,h):
261         #Assume h is not null
262         if not h.right:
263             return h
264         else:
265             return self.__max(h.right)
266 
267     def __floor(self,h,key):
268         '''Find the NODE with key <= given key in the tree rooted at h '''
269         if not h:
270             return None
271         tag = cmp(key,h.key)
272         if tag == 0:
273             return h
274         if tag < 0:
275             return self.__floor(h.left,key)
276         t = self.__floor(h.right,key)
277         if t:#if find in right tree
278             return t
279         else:#else return itself
280             return h
281 
282     def __ceil(self,h,key):
283         '''Find the NODE with key >= given key in the tree rooted at h '''
284         if not h:
285             return None
286         tag = cmp(key,h.key)
287         if tag == 0:
288             return h
289         if tag > 0: # key is bigger
290             return self.__ceil(h.right,key)
291         t = self.__ceil(h.left,key)#key is lower.Try to find ceil left
292         if t:#if find in left tree
293             return t
294         else:#else return itself
295             return h
296 
297     def __index(self,h,key):
298         if not h:
299             return 0
300         tag = cmp(key,h.key)
301         if tag < 0:
302             return self.__index(h.left,key)
303         elif tag > 0:   #Key is bigger
304             return self.__index(h.right,key) + 1 + self.__size(h.left)
305         else:   #Eq
306             return self.__size(h.left)
307 
308     def __select(self,h,index):
309         '''assert h. assert 0 <= index < size(tree) '''
310         l_size = self.__size(h.left)
311         if l_size > index:
312             return self.__select(h.left,index)
313         elif l_size < index:
314             return self.__select(h.right,index - l_size - 1)
315         else:
316             return h
317 
318     def __range(self,h,q,lo,hi):
319         if not h: 
320             return
321         tag_lo = cmp(lo,h.key)
322         tag_hi = cmp(hi,h.key)
323         if tag_lo < 0:#lo key is lower than h.key
324             self.__range(h.left,q,lo,hi)
325         if tag_lo <= 0 and tag_hi >= 0:
326             q.append(h.key)
327         if tag_hi > 0 :# hi key is bigger than h.key
328             self.__range(h.right,q,lo,hi)
329 
330 
331     #===============Adjust Functions=============#
332     def __rotate_right(self,h):
333         x = h.left
334         h.left,x.right = x.right,h
335         x.color,x.N = h.color,h.N
336         h.color,h.N = True,self.__size(h.left) + self.__size(h.right) + 1
337         return x
338 
339     def __rotate_left(self,h):
340         x = h.right
341         h.right,x.left = x.left,h
342         x.color,x.N = h.color,h.N
343         h.color,h.N = True,self.__size(h.left) + self.__size(h.right) + 1
344         return x
345 
346     def __flip_colors(self,h):
347         h.color = not h.color
348         h.left.color = not h.left.color
349         h.right.color = not h.right.color
350 
351     def __move_red_left(self,h):
352         self.__flip_colors(h)
353         if self.__is_red(h.right.left):
354             h = self.__rotate_left(h)
355         return h
356 
357     def __move_red_right(self,h):
358         self.__flip_colors(h)
359         if self.__is_red(h.left.left):
360             h = self.__rotate_right(h)
361         return h
362 
363     def __balance(self,h):
364         if self.__is_red(h.right): 
365             h = self.__rotate_left(h)
366         if self.__is_red(h.left) and self.__is_red(h.left.left): 
367             h = self.__rotate_right(h)
368         if self.__is_red(h.left) and self.__is_red(h.right):
369             self.__flip_colors(h)
370         h.N = self.__size(h.left) + self.__size(h.right) + 1
371         return h
372     
373     #Class Method
374     @staticmethod
375     def __is_red(x):
376         return False if not x else x.color
377 
378     @staticmethod
379     def __is_black(x):
380         return True  if not x else not x.color
381 
382     @staticmethod
383     def __size(x):
384         return 0 if not x else x.N
385 
386 
def RBT_testing():
    '''
    API Examples: exercise most RBT methods on "SEARCHXMPL" and print results.

    Uses only single-argument print(...) calls so it runs unchanged (and
    prints the same text) on Python 2 and Python 3.
    '''
    t = RBT()
    test_data = "SEARCHXMPL"

    print('=====testing is_empty()\nBefore Insertion')
    print(t.is_empty())

    for letter in test_data:
        t.put(letter, [ord(letter)])
        print("Test Inserting:%s, tree size is %d" % (letter, t.size()))
    print("After insertion it return:")
    print(t.is_empty())
    print("====test is_empty complete\n")

    print("=====Tesing Get method:")
    print("get 's' is ")
    print(t.get('S'))
    print("get 'H' is ")
    print(t.get('H'))

    print('==Trying get null key: get "F" is')
    print(t.get('F'))
    print("=====Testing Get method end\n\n")

    print("=====Testing ceil and floor")
    print("Ceil('L')")
    print(t.ceil('L'))
    print("Ceil('F') *F is not in tree")
    print(t.ceil('F'))

    # These previously printed ceil() results under the "Floor" labels.
    print("Floor('L')")
    print(t.floor('L'))
    print("Floor('F')")
    print(t.floor('F'))

    print('======test append method')
    print('Orient key e is correspond with')
    print(t.get('E'))
    t.append('E', [4])
    print('==After append')
    print(t.get('E'))
    print("=====Testing Append method end\n\n")

    print("=====Testing index()")
    print("index(E)")
    print(t.index('E'))
    # "%s %s" reproduces the Python 2 output of `print a, b`.
    print("index(L),select(4)")
    print("%s %s" % (t.index('L'), t.select(4)))
    print("index('M'),select(5)")
    print("%s %s" % (t.index('M'), t.select(5)))
    print("index a key not in tree:\n index('N'),select(6)")
    print("%s %s" % (t.index('N'), t.select(6)))
    print("index('P')")
    print(t.index('P'))

    print("=====Testing select")
    print("select(3) = ")
    print(t.select(3))
    print("select and index end...\n\n")

    print("====Tesing Min and Max")
    print("min key is:")
    print(t.min())
    print("max key is")
    print(t.max())

    print("==How much between min and max:")
    print(t.width(t.min(), t.max()))
    print("keys between min and max:")
    print(t.keys())
    print("keys in 'E' and 'M' ")
    print(t.range('E', 'M'))

    print("try to delete min_key:")
    print("But we could try contains('A') first")
    print(t.contains('A'))
    t.del_min()
    print("After deletion t.contains('A') is ")
    print(t.contains('A'))

    print(t.min())
    print("try to kill one more min key:")
    t.del_min()
    print(t.min())
    print("try to delete max_key,New Max key is :")
    t.del_max()
    print(t.max())
    print("=====Tesing Min and Max complete\n\n")

    print('=====Deleting Test')
    print(t.size())
    t.delete('H')
    print(t.size())

    print('Delete a non-exists key:')
    try:
        t.delete('F')
    except LookupError:   # delete() raises LookupError for a missing key
        print("*Look up error occur*")

    print("=====Testing Delete method complete")
494 
def test_basic_api():
    '''
    Demo of put(), get(), append() and delete(); prints every step.

    Uses only single-argument print(...) calls so it runs unchanged (and
    prints the same text) on Python 2 and Python 3.
    '''
    print("==========Testing Basic API==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM")
    test_data = "FENGDIJKABCLM"  # from A-N, without H

    #=====put()
    print("==========put() test begin!==========")
    for letter in test_data:
        t.put(letter, [ord(letter)])  # Value is [ascii order of letter]
        print("put(%s); Now tree size is %d" % (letter, t.size()))
    print('Final tree size is %d' % t.size())
    print("==========put() test complete!==========\n")

    #=====get()
    print("==========get() test begin!==========")
    print("get('F'):\t%s" % repr(t.get('F')))
    print("get('A'):\t%s" % repr(t.get('A')))
    print("get a non-exist key Z: get('Z'):\t%s" % repr(t.get('Z')))
    print("==========get() test complete!==========\n")

    #=====append()
    # append() returns None; it is still called inside the print as before,
    # so the printed return value is always 'None'.
    print("=====append() test begin!==========")
    print("First append to a exist key:[F]")
    print("Before Append:get('F'):\t%s" % repr(t.get('F')))
    print("append('F',[3,'haha']):\t%s" % repr(t.append('F', [3, 'haha'])))
    print("After Append:get('F'):\t%s\n" % repr(t.get('F')))
    print("Second append to a non-exist key:[O]")
    print("Before Append:get('O'):\t%s" % repr(t.get('O')))
    print("append a non-exist key O: append('O',['value of O']):\t%s" % repr(t.append('O', ['value of O'])))
    print("After Append:get('O'):\t%s\n" % repr(t.get('O')))
    print("==========append() test complete!==========\n")

    #=====delete()  (in reverse insertion order; 'O' stays in the tree)
    print("==========delete() test begin!==========")
    for letter in reversed(test_data):
        t.delete(letter)
        print("delete(%s); Now tree size is %d" % (letter, t.size()))
    print('Final tree size is %d' % t.size())
    print("==========delete() test complete!==========\n")

    print("==========Basic API Test Complete==========\n\n")
539 
def test_advance_api():
    '''
    Demo of min()/del_min() and max()/del_max(); prints every step.

    Uses only single-argument print(...) calls so it runs unchanged on
    Python 2 and Python 3.
    '''
    # The old header listed floor/ceil/above/below, which this suite does
    # not exercise (test_int_api does).
    print("==========Testing min max del_min del_max ==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM")
    test_data = "FENGDIJKABCLM"  # from A-N, without H
    for letter in test_data:
        t.put(letter, [ord(letter)])  # Value is [ascii order of letter]

    #=====min() and del_min()
    print("==========min() and del_min() test begin!==========")
    print("Original min():\t%s" % repr(t.min()))
    print("run del_min()")
    t.del_min()
    print("After del_min:min()\t%s" % repr(t.min()))

    print("run del_min() again")
    t.del_min()
    print("After del_min run again:min()\t%s" % repr(t.min()))

    #=====max() and del_max()
    print("=====max() and del_max() test begin!")
    print("Original max():\t%s" % repr(t.max()))
    print("run del_max()")
    t.del_max()
    print("After del_max:max()\t%s" % repr(t.max()))

    print("run del_max() again")
    t.del_max()
    print("After del_max run again:max()\t%s" % repr(t.max()))
    print("==========min() max() del_min() del_max() test complete!==========\n")
569 
def test_int_api():
    '''
    Print a table of ceil/floor/above/below for probe keys, some of which
    (A, H, N) are deliberately absent from the tree.

    Uses only single-argument print(...) calls so it runs unchanged on
    Python 2 and Python 3.
    '''
    #======ceil floor above below
    print("==========Testing ceil floor above below ==========")
    t = RBT()
    print("Test Data: FENGDIJKABCLM - [AHN] = FEGDIJKBCLM")
    test_data = "FEGDIJKBCLM"  # from A-N, Del A H N

    for letter in test_data:
        t.put(letter, [ord(letter)])  # Value is [ascii order of letter]
    print("Node\tceil\t\tfloor\t\tabove\t\tbelow")
    for P in ['A', 'B', 'C', 'G', 'H', 'I', 'L', 'M', 'N']:
        print("%s\t%s\t%s\t%s\t%s" % (P, t.ceil(P), t.floor(P), t.above(P), t.below(P)))
582 
if '__main__' == __name__:
    # Script entry point: run the demo suites in a fixed order.
    # (RBT_testing() is an extra example and is not part of this run.)
    for _suite in (test_basic_api, test_advance_api, test_int_api):
        _suite()
View Code

查找操作的数据结构不断进化,才有了红黑树:从链表(顺序查找)到二叉查找树,再到2-3树,最后到红黑树。

红黑树本质上是用二叉查找树的形式来表示2-3树:把2-3树中的每个3-结点拆成两个由红链接相连的2-结点,黑链接则对应2-3树中的普通链接。

《算法导论》也好,其他什么乱七八糟算法书博客也罢,讲红黑树都没讲到本质。

Sedgewick的《算法(4th)》这本书就很不错:起码他告诉你红黑树是怎么来的。 

 仔细理解2-3树与红黑树的相同之处,才能对那些乱七八糟的插入删除调整操作有直观的认识。