AVL树C++实现
说来惭愧,工作三年了一直都没有真正弄懂AVL树的原理。因为最近在看STL源码,但STL的map和set的底层数据结构是红黑树,而红黑树是不严格的AVL树,所以理解红黑树之前必须要先弄懂AVL树。借此契机,将AVL树从原理和代码层面拿下。
1. AVL树简介
AVL树种的任意节点的左右子树的高度差的绝对值最大为1,其本质是带了平衡功能的二叉搜索树。
二叉搜索树在数据极端情况下会退化成单链表,时间复杂度也会退化成O(n)。而AVL树定义了旋转操作,在平衡因子大于2时,通过旋转来调整树的结构,来重新满足平衡因子小于2,确保在查找、插入和删除在平均和最坏情况下都是O(logn)。
2. AVL旋转
AVL旋转是AVL树最核心的部分,需要重点掌握。在理解AVL旋转之前先知道以下几个概念:
- AVL树节点的插入总是在叶子节点;
- AVL树在插入节点之前是满足平衡条件的;
- 插入新节点后有可能满足平衡条件也可能不满足;
- 当不满足平衡条件时需要对新的树进行旋转。
旋转之前首先需要找到插入节点向上第一个不平衡的节点(记为A),新插入节点只能在A的的左子树的左子树、左子树的右子树、右子树的左子树、右子树的右子树上,对应四种不同的旋转方式。
#ifndef AVL_TREE_H #define AVL_TREE_H #include <algorithm> template <typename T> class avltree { public: struct Node { Node(T x) : val(x), left(nullptr), right(nullptr) {} Node(const Node* n) : val(n->val), left(n->left), right(n->right) {} T val; Node* left; Node* right; }; public: avltree() : root(nullptr) {} ~avltree() { destroy(root); } void insert(const T& val) { root = insert(root, val); } void remove(const T& val) { root = remove(root, val); } Node* get_root() { return root; } static int balance_fector(Node* node) { if (node == nullptr) return 0; return height(node->left) - height(node->right); } private: void destroy(Node* node) { if (node != nullptr) { destroy(node->left); destroy(node->right); delete node; } } Node* insert(Node* node, const T& val) { if (node == nullptr) return new Node(val); if (val == node->val) return node; if (val < node->val) node->left = insert(node->left, val); else node->right = insert(node->right, val); return rebalance(node); } Node* remove(Node* node, const T& val) { if (node == nullptr) return node; if (val < node->val) { node->left = remove(node->left, val); } else if (val > node->val) { node->right = remove(node->right, val); } else { if (node->left == nullptr) { Node* del = node; node = node->right; delete del; } else if (node->right == nullptr) { Node* del = node; node = node->left; delete del; } else { Node* successor = new Node(minimum(node->right)); node->right = remove(node->right, successor->val); successor->left = node->left; successor->right = node->right; delete node; node = successor; } } return rebalance(node); } Node* minimum(Node* node) { while (node->left) { node = node->left; } return node; } Node* rebalance(Node* node) { int fector = balance_fector(node); if (fector == 2) { if (balance_fector(node->left) > 0) node = rightRotate(node); else node = leftRightRotate(node); } if (fector == -2) { if (balance_fector(node->right) < 0) node = leftRotate(node); else node = rightLeftRotate(node); } return node; } static int height(Node* node) { if (node == nullptr) return 0; return std::max(height(node->left), height(node->right)) + 1; } Node* rightRotate(Node* node) //LL { Node* left = node->left; node->left = left->right; left->right = node; return left; } Node* leftRotate(Node* node) //RR { Node* right = node->right; node->right = right->left; right->left = node; return right; } Node* leftRightRotate(Node* node) //LR { node->left = leftRotate(node->left); return rightRotate(node); } Node* rightLeftRotate(Node* node) //RL { node->right = rightRotate(node->right); return leftRotate(node); } private: Node* root; }; #endif