AVL树C++实现

说来惭愧,工作三年了一直都没有真正弄懂AVL树的原理。因为最近在看STL源码,但STL的map和set的底层数据结构是红黑树,而红黑树是不严格的AVL树,所以理解红黑树之前必须要先弄懂AVL树。借此契机,将AVL树从原理和代码层面拿下。

1. AVL树简介

AVL树种的任意节点的左右子树的高度差的绝对值最大为1,其本质是带了平衡功能的二叉搜索树

二叉搜索树在数据极端情况下会退化成单链表,时间复杂度也会退化成O(n)。而AVL树定义了旋转操作,在平衡因子大于2时,通过旋转来调整树的结构,来重新满足平衡因子小于2,确保在查找、插入和删除在平均和最坏情况下都是O(logn)。

2. AVL旋转

AVL旋转是AVL树最核心的部分,需要重点掌握。在理解AVL旋转之前先知道以下几个概念:

  • AVL树节点的插入总是在叶子节点;
  • AVL树在插入节点之前是满足平衡条件的;
  • 插入新节点后有可能满足平衡条件也可能不满足;
  • 当不满足平衡条件时需要对新的树进行旋转。

旋转之前首先需要找到插入节点向上第一个不平衡的节点(记为A),新插入节点只能在A的的左子树的左子树、左子树的右子树、右子树的左子树、右子树的右子树上,对应四种不同的旋转方式。

#ifndef AVL_TREE_H
#define AVL_TREE_H

#include <algorithm>

template <typename T>
class avltree {
public:
    struct Node {
        Node(T x) 
            : val(x), left(nullptr), right(nullptr) {}
        Node(const Node* n) 
            : val(n->val), left(n->left), right(n->right) {}
        T val;
        Node* left;
        Node* right;
    };
public:
    avltree() : root(nullptr) {}
    ~avltree()
    {
        destroy(root);
    }
    void insert(const T& val)
    {
        root = insert(root, val);
    }
    void remove(const T& val)
    {
        root = remove(root, val);
    }
    Node* get_root()
    {
        return root;
    }
    static int balance_fector(Node* node) 
    {
        if (node == nullptr) 
            return 0;
        return height(node->left) - height(node->right);
    }

private:
    void destroy(Node* node)
    {
        if (node != nullptr)
        {
            destroy(node->left);
            destroy(node->right);
            delete node;
        }
    }
    Node* insert(Node* node, const T& val)
    {
        if (node == nullptr)
            return new Node(val);
        
        if (val == node->val) 
            return node;
        if (val < node->val)
            node->left = insert(node->left, val);
        else
            node->right = insert(node->right, val);

        return rebalance(node);
    }
    Node* remove(Node* node, const T& val)
    {
        if (node == nullptr) return node;
        if (val < node->val)
        {
            node->left = remove(node->left, val);
        }
        else if (val > node->val)
        {
            node->right = remove(node->right, val);
        } 
        else 
        {
            if (node->left == nullptr)
            {
                Node* del = node;
                node = node->right;
                delete del;
            } 
            else if (node->right == nullptr)
            {
                Node* del = node;
                node = node->left;
                delete del;
            }
            else
            {
                Node* successor = new Node(minimum(node->right));
                node->right = remove(node->right, successor->val);
                successor->left = node->left;
                successor->right = node->right;
                delete node;
                node = successor;
            }
        }
        return rebalance(node);
    }
    Node* minimum(Node* node) 
    {
        while (node->left)
        {
            node = node->left;
        } 
        return node;
    }

    Node* rebalance(Node* node)
    {
        int fector = balance_fector(node);
        if (fector == 2)
        {
            if (balance_fector(node->left) > 0)
                node = rightRotate(node);
            else
                node = leftRightRotate(node);
        }
        if (fector == -2) 
        {
            if (balance_fector(node->right) < 0)
                node = leftRotate(node);
            else
                node = rightLeftRotate(node);
        }
        return node;
    }
    static int height(Node* node) 
    {
        if (node == nullptr) 
            return 0;
        return std::max(height(node->left), height(node->right)) + 1;
    }
    Node* rightRotate(Node* node)  //LL
    {
        Node* left = node->left;
        node->left = left->right;
        left->right = node;
        return left;
    }
    Node* leftRotate(Node* node)  //RR
    {
        Node* right = node->right;
        node->right = right->left;
        right->left = node;
        return right;
    }
    Node* leftRightRotate(Node* node) //LR
    {
        node->left = leftRotate(node->left);
        return rightRotate(node);
    }
    Node* rightLeftRotate(Node* node)  //RL
    {
        node->right = rightRotate(node->right);
        return leftRotate(node);
    }

private:
    Node* root;
};

#endif

 

posted @ 2019-07-20 17:30  evenleo  阅读(955)  评论(3编辑  收藏  举报