class Tree(object):
    '''
    A single node of the AVL tree.

    Attributes:
        data: the key stored in this node.
        parent: the parent node, or None when this node is the root.
        left, right: the child nodes, or None when absent.
        factor: a balance-factor value stored on the node by the caller;
            it is never updated automatically by the node itself.
    '''
    def __init__(self, data, parent, left=None, right=None, factor=0):
        # Payload and upward link.
        self.data, self.parent = data, parent
        # Downward links.
        self.left, self.right = left, right
        self.factor = factor
        
class AVLTree(object):
    '''
    A binary search tree of ``Tree`` nodes with AVL-style rebalancing helpers.

    ``insertNode`` only links a new leaf in. Rebalancing is driven by the
    caller, which walks the returned path bottom-up, inspects
    ``calculate_factor`` and calls ``rotate_left`` / ``rotate_right`` when a
    node's balance factor reaches +2 / -2. Duplicate keys are rejected.
    '''
    def __init__(self):
        self.root = None

    def insertNode(self, key):
        '''
        Insert ``key`` as a new leaf and return the search path.

        Returns:
            The list of nodes visited from the root down to the parent of the
            new leaf. For an empty tree this is ``[root]``, i.e. the new node
            itself. If ``key`` is already present, a message is printed,
            nothing is inserted, and the path ending at the existing node is
            returned.
        '''
        if self.root is None:
            self.root = Tree(key, None)
            return [self.root]
        traverse_path = []
        node = self.root
        while True:
            traverse_path.append(node)
            if key > node.data:
                if node.right is None:
                    node.right = Tree(key, node)
                    return traverse_path
                node = node.right
            elif key < node.data:
                if node.left is None:
                    node.left = Tree(key, node)
                    return traverse_path
                node = node.left
            else:
                # %s (not %d) so that non-integer keys do not raise TypeError;
                # the tuple keeps tuple-valued keys from being unpacked.
                print("%s is in BinaryTree" % (key,))
                return traverse_path

    def searchKeyPath(self, key):
        '''
        Return the ancestors of the node holding ``key``, root first.

        The node holding ``key`` itself is not included, so searching for the
        root's key returns ``[]``. Returns None when the tree is empty (after
        printing 'root is empty') or when ``key`` is not in the tree.
        '''
        if self.root is None:
            print('root is empty')
            return None
        parent_path = []
        node = self.root
        while node is not None:
            if key < node.data:
                parent_path.append(node)
                node = node.left
            elif key > node.data:
                parent_path.append(node)
                node = node.right
            else:
                return parent_path
        return None

    def calculate_depth(self, node):
        '''
        Return the height of the subtree rooted at ``node``.

        A leaf has height 0 and an empty subtree (``None``) has height -1.
        The height is recomputed recursively on every call, so the cost is
        proportional to the size of the subtree.
        '''
        if node is None:
            return -1
        return 1 + max(self.calculate_depth(node.left),
                       self.calculate_depth(node.right))

    def calculate_factor(self, node):
        '''
        Return the balance factor of ``node``: height(left) - height(right).

        A positive value means the node is left-heavy and a negative value
        means it is right-heavy. ``None`` has a factor of 0.
        '''
        if node is None:
            return 0
        return self.calculate_depth(node.left) - self.calculate_depth(node.right)

    # Balance maintenance. Unlike the textbook LL/LR/RR/RL rotations, there
    # are only two operations here, named after the heavy side: rotate_left
    # (left subtree too tall) and rotate_right (right subtree too tall). The
    # resulting node arrangement can differ from the textbook rotations, but
    # the in-order (sorted) sequence of keys is preserved.

    def _replace_in_parent(self, node, new_root):
        '''Put ``new_root`` where ``node`` hangs and make it ``node``'s parent.'''
        parent = node.parent
        if parent is None:
            self.root = new_root
        elif parent.left is node:
            parent.left = new_root
        else:
            parent.right = new_root
        new_root.parent = parent
        node.parent = new_root

    @staticmethod
    def _report_rotate_path(path):
        '''Print ``path`` forwards and backwards (debug output); return it reversed.'''
        print('rotate node parent path ')
        for element in path:
            print(element.data, end=' ')
        reversed_path = path[::-1]
        for element in reversed_path:
            print(element.data, end=' ')
        print('end rotate node parent path')
        return reversed_path

    def rotate_left(self, node):
        '''
        Rebalance the left-heavy subtree rooted at ``node``.

        The node with the largest key in ``node.left``'s subtree (``temp``)
        becomes the new root of the subtree:

        * If ``node.left`` has no right child (LL-like), ``temp`` is
          ``node.left`` itself. This is the classic single rotation.
        * Otherwise (LR-like), ``temp`` is the end of the right spine of
          ``node.left``. Its left child takes its place on the spine, and
          ``temp`` is lifted above both ``node.left`` and ``node``.

        In both cases ``node`` ends up as ``temp.right`` with no left child.
        The collected path is printed (debug output) and returned reversed:
        ``[node, <spine nodes above temp, deepest first>..., node]``.
        '''
        rotate_left_path = [node]
        temp = node.left
        if temp.right is not None:
            # LR-like: walk down the right spine to the largest key,
            # remembering the spine so the caller can re-check those nodes.
            parent = temp
            rotate_left_path.append(parent)
            temp = temp.right
            while temp.right is not None:
                parent = temp
                rotate_left_path.append(parent)
                temp = temp.right
            self._replace_in_parent(node, temp)
            # Detach temp: its left subtree takes its place on the spine.
            parent.right = temp.left
            if temp.left is not None:
                temp.left.parent = parent
            # Lift temp above both node.left and node.
            temp.left = node.left
            node.left.parent = temp
            temp.right = node
            node.left = None
        else:
            # LL-like: temp has no right subtree, so node just loses its
            # left child and hangs to the right of temp.
            self._replace_in_parent(node, temp)
            temp.right = node
            node.left = None

        rotate_left_path.append(node)
        return self._report_rotate_path(rotate_left_path)

    def rotate_right(self, node):
        '''
        Rebalance the right-heavy subtree rooted at ``node``.

        This is the mirror image of ``rotate_left``. The node with the
        smallest key in ``node.right``'s subtree (``temp``) becomes the new
        root of the subtree:

        * If ``node.right`` has no left child (RR-like), ``temp`` is
          ``node.right`` itself.
        * Otherwise (RL-like), ``temp`` is the end of the left spine of
          ``node.right``. Its right child takes its place on the spine.

        In both cases ``node`` ends up as ``temp.left`` with no right child.
        The collected path is printed (debug output) and returned reversed:
        ``[node, <spine nodes above temp, deepest first>..., node]``.
        '''
        rotate_right_path = [node]
        temp = node.right
        if temp.left is not None:
            # RL-like: walk down the left spine to the smallest key.
            parent = temp
            rotate_right_path.append(parent)
            temp = temp.left
            while temp.left is not None:
                parent = temp
                rotate_right_path.append(parent)
                temp = temp.left
            self._replace_in_parent(node, temp)
            temp.left = node
            # Detach temp: its right subtree takes its place on the spine
            # (read temp.right before it is overwritten below).
            parent.left = temp.right
            if temp.right is not None:
                temp.right.parent = parent
            # Lift temp above node.right.
            temp.right = node.right
            node.right.parent = temp
            node.right = None
        else:
            # RR-like: temp has no left subtree, so node just loses its
            # right child and hangs to the left of temp.
            self._replace_in_parent(node, temp)
            temp.left = node
            node.right = None

        rotate_right_path.append(node)
        return self._report_rotate_path(rotate_right_path)

    def beforeAndFactorIterator(self):
        '''
        Yield ``(data, factor)`` for every node in pre-order.

        ``factor`` is the value stored on the node, not a freshly computed one.
        '''
        node = self.root
        stack = []
        while node is not None or stack:
            while node is not None:
                yield (node.data, node.factor)
                stack.append(node)
                node = node.left
            node = stack.pop().right

    def beforeIterator(self):
        '''Yield every node's data in pre-order.'''
        for data, _factor in self.beforeAndFactorIterator():
            yield data

    def __iter__(self):
        '''Yield every node's data in in-order, i.e. sorted by key.'''
        node = self.root
        stack = []
        while node is not None or stack:
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            yield node.data
            node = node.right

    def postIterator(self):
        '''
        Yield every node's data in post-order.

        Uses an explicit stack plus the last node emitted to tell whether a
        node's right subtree is already done, so nodes are not mutated and
        the traversal can be repeated.
        '''
        stack = []
        node = self.root
        last_visited = None
        while node is not None or stack:
            if node is not None:
                stack.append(node)
                node = node.left
            else:
                top = stack[-1]
                if top.right is not None and last_visited is not top.right:
                    # Right subtree not visited yet: descend into it first.
                    node = top.right
                else:
                    yield top.data
                    last_visited = stack.pop()
  
  
if __name__ == '__main__':
    # Other inputs that were used while developing:
    #     [62, 58, 88, 48, 73, 99, 35, 51, 93, 29, 37, 49, 56, 36, 50]
    #     [20, 10, 30, 15, 5, 1]

    # Accuracy test: insert 1..99 in ascending order (sorted input is the
    # worst case for an unbalanced BST), rebalancing after every insertion.
    lis = list(range(1, 100))
    print(lis)
    avl_tree = AVLTree()
    for key in lis:
        traverse_path = avl_tree.insertNode(key)
        # Re-check the insertion path bottom-up (deepest node first).
        while traverse_path:
            node = traverse_path.pop()
            factor = avl_tree.calculate_factor(node)
            if factor >= 2:
                # Left-heavy: dump the pre-order, then rotate.
                for data in avl_tree.beforeIterator():
                    print(data, end=" ")
                print('end')
                print("rotate_left %d" % node.data)
                rotate_path = avl_tree.rotate_left(node)
                factor = avl_tree.calculate_factor(node)
                # Queue the nodes on the returned rotation path for
                # re-checking.
                # NOTE(review): the new subtree root (node.parent after the
                # rotation) is not part of that path, so its own factor is
                # not re-checked here; confirm this is intended.
                for element in rotate_path:
                    if element not in traverse_path:
                        traverse_path.append(element)
            elif factor <= -2:
                # Right-heavy: dump the pre-order, then rotate.
                for data in avl_tree.beforeIterator():
                    print(data, end=" ")
                print('before rotate_right')
                print("rotate_right %d" % node.data)
                rotate_path = avl_tree.rotate_right(node)
                factor = avl_tree.calculate_factor(node)
                for element in rotate_path:
                    if element not in traverse_path:
                        traverse_path.append(element)

            # Store the (post-rotation) factor on the node for display.
            node.factor = factor
            for data, node_factor in avl_tree.beforeAndFactorIterator():
                print("d=%d f=%d" % (data, node_factor), end=", ")
            print("end for")

    for data in avl_tree:
        print(data, end=' ')
    print('end\n before traversing')

    for data, node_factor in avl_tree.beforeAndFactorIterator():
        print("d=%d f=%d" % (data, node_factor), end=", ")
    print('end')

 

标准的 LL/LR/RR/RL 旋转方式见下方参考链接 [1]。

 

Reference

[1] https://www.cnblogs.com/sfencs-hcy/p/10356467.html