AVL的C++实现

网上已经有一些AVL树的C++实现，我也学习了几篇，例如我参考的这篇博客：https://www.cnblogs.com/maybe2030/p/4732377.html。这篇文章对AVL树思想的描述没有问题，但实现中存在一些bug，主要出在树高度的维护和旋转部分。我读了这位作者的实现后自己写了一份，验证时发现我的代码通过了测试，而参考博客的代码没有通过。测试方法是：随机打乱0-1023共1024个数（下面main.cpp中的len实际设为102400，即打乱0-102399），逐个插入AVL树，每次插入后都调用isAVL()判断树是否仍然平衡；插入结束后再逐个删除，同样在每次删除后调用isAVL()判断删除是否导致树失衡。这份代码旨在理清算法，销毁树的内存回收函数和查找函数我都没写。欢迎大家测试代码的正确性，如果这份代码有bug，请留言，我会继续调试（我觉得有bug的概率很低，毕竟这个测试还是相当严格的）。

//avltree.h
#ifndef AVLTREE_H_INCLUDED
#define AVLTREE_H_INCLUDED
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <vector>

using std::vector;
using std::max;

template<class T>
class BNode
{
public:
    // Child links; NULL means the subtree is empty.
    BNode *left,*right;
    // Cached height of the subtree rooted here: a freshly created leaf has 1.
    // The default-constructed node starts at 0.
    int height;
    T val;
    // val() value-initialises the payload (0 for arithmetic types, empty for
    // class types such as std::string). The previous val(0) only worked for
    // types constructible from the integer 0.
    BNode():left(NULL),right(NULL),height(0),val() {}
    // Creates a leaf holding _val. Taking the argument by const reference
    // avoids an extra copy for heavyweight T.
    BNode(const T &_val):left(NULL),right(NULL),height(1),val(_val) {}
};

template<class T>
class AVLTree
{
private:
    BNode<T> *root;
    void traversal(BNode<T> *rt,vector<T>&vec);
    BNode<T>* AVL_insert(BNode<T> *rt,T x);
    BNode<T>* AVL_delete(BNode<T> *rt,T x);
    int get_height(BNode<T> *rt);
    BNode<T>* rotate_LL(BNode<T> *rt);
    BNode<T>* rotate_LR(BNode<T> *rt);
    BNode<T>* rotate_RL(BNode<T> *rt);
    BNode<T>* rotate_RR(BNode<T> *rt);
    // just for testing
    bool isAVL(BNode<T> *rt);
public:
    AVLTree<T>():root(NULL) {}
    //此处应该实现销毁内存操作
    ~AVLTree<T>() {}
    vector<T> traversal();
    void insert(T x);
    void erase(T x);
    bool isAVL();

};

template<class T>
bool AVLTree<T>::isAVL(BNode<T> *rt)
{
    // Testing helper: validates the subtree rooted at rt. For every node it
    // checks that
    //   1. the cached height equals 1 + the taller child's cached height.
    //      Applied bottom-up from the leaves, this proves every stored height
    //      is the true height, so a stale height can no longer hide an
    //      imbalance;
    //   2. the two children's heights differ by at most one (AVL balance).
    // BST ordering of the values is NOT checked here.
    if(!rt) return true;
    int left_height = get_height(rt->left);
    int right_height = get_height(rt->right);
    if(rt->height != max(left_height,right_height) + 1) return false;
    // Plain range check instead of abs(), which needs <cstdlib>.
    int diff = left_height - right_height;
    if(diff < -1 || diff > 1) return false;
    return isAVL(rt->left) && isAVL(rt->right);
}

template<class T>
bool AVLTree<T>::isAVL()
{
    // Public entry point: run the recursive check from the tree's root.
    BNode<T> *start = root;
    return isAVL(start);
}

template<class T>
BNode<T>* AVLTree<T>::rotate_LL(BNode<T> *rt)
{
    // Single right rotation, used for a left-left imbalance.
    // rt's left child `pivot` becomes the subtree root, rt becomes pivot's
    // right child, and pivot's former right subtree is re-attached as rt's left.
    BNode<T> *pivot = rt->left;
    BNode<T> *moved = pivot->right;
    pivot->right = rt;
    rt->left = moved;
    // rt now sits below pivot, so its height must be refreshed first.
    const int rt_height = max(get_height(rt->left), get_height(rt->right)) + 1;
    rt->height = rt_height;
    pivot->height = max(get_height(pivot->left), rt_height) + 1;
    return pivot;
}

template<class T>
BNode<T>* AVLTree<T>::rotate_RR(BNode<T> *rt)
{
    // Single left rotation, used for a right-right imbalance.
    // rt's right child `pivot` becomes the subtree root, rt becomes pivot's
    // left child, and pivot's former left subtree is re-attached as rt's right.
    BNode<T> *pivot = rt->right;
    BNode<T> *moved = pivot->left;
    pivot->left = rt;
    rt->right = moved;
    // rt now sits below pivot, so its height must be refreshed first.
    const int rt_height = max(get_height(rt->left), get_height(rt->right)) + 1;
    rt->height = rt_height;
    pivot->height = max(get_height(pivot->right), rt_height) + 1;
    return pivot;
}

template<class T>
BNode<T>* AVLTree<T>::rotate_LR(BNode<T> *rt)
{
    // Left-right case: first rotate the left child leftwards, which turns the
    // imbalance into a left-left case, then rotate rt rightwards.
    BNode<T> *straightened = rotate_RR(rt->left);
    rt->left = straightened;
    return rotate_LL(rt);
}

template<class T>
BNode<T>* AVLTree<T>::rotate_RL(BNode<T> *rt)
{
    // Right-left case: first rotate the right child rightwards, which turns
    // the imbalance into a right-right case, then rotate rt leftwards.
    BNode<T> *straightened = rotate_LL(rt->right);
    rt->right = straightened;
    return rotate_RR(rt);
}

template<class T>
int AVLTree<T>::get_height(BNode<T> *rt)
{
    // Null-safe height lookup: an empty subtree has height 0; otherwise the
    // node's cached height is returned.
    return rt == NULL ? 0 : rt->height;
}

template<class T>
void AVLTree<T>::insert(T x)
{
    // AVL_insert hands back the (possibly rotated) root of the whole tree.
    BNode<T> *new_root = AVL_insert(root,x);
    root = new_root;
}

template<class T>
void AVLTree<T>::erase(T x)
{
    // AVL_delete hands back the (possibly rotated or emptied) root.
    BNode<T> *new_root = AVL_delete(root,x);
    root = new_root;
}

template<class T>
BNode<T> *AVLTree<T>::AVL_insert(BNode<T> *rt,T x)
{
    // Inserts x into the subtree rooted at rt and returns the new subtree
    // root, which differs from rt when a rotation happens at this level.
    // A value equal to one already present is ignored. Heights are refreshed
    // on the way back up the recursion.
    //
    // Both sides now recurse symmetrically. The helper no longer writes the
    // member `root` itself; the public insert() stores the returned pointer.
    if(!rt) return new BNode<T>(x);   // empty spot reached: new leaf (height 1)
    if(x == rt->val) return rt;
    if(x < rt->val)
    {
        rt->left = AVL_insert(rt->left,x);
        int left_height = get_height(rt->left);
        int right_height = get_height(rt->right);
        rt->height = max(left_height,right_height) + 1;
        if(left_height - right_height == 2)
        {
            // x landed in left->left: single rotation; in left->right: double.
            if(x < rt->left->val) rt = rotate_LL(rt);
            else rt = rotate_LR(rt);
        }
    }
    else
    {
        rt->right = AVL_insert(rt->right,x);
        int left_height = get_height(rt->left);
        int right_height = get_height(rt->right);
        rt->height = max(left_height,right_height) + 1;
        if(right_height - left_height == 2)
        {
            // x landed in right->left: double rotation; in right->right: single.
            if(x < rt->right->val) rt = rotate_RL(rt);
            else rt = rotate_RR(rt);
        }
    }
    return rt;
}

template<class T>
BNode<T>* AVLTree<T>::AVL_delete(BNode<T> *rt,T x)
{
    // Removes x from the subtree rooted at rt (no-op if x is absent) and
    // returns the new subtree root after any rebalancing rotation.
    if(!rt) return NULL;

    // Records which side lost a node: only the opposite side can now be two
    // levels taller.
    bool left_side_shrank = false;

    if(x < rt->val)
    {
        rt->left = AVL_delete(rt->left,x);
        left_side_shrank = true;
    }
    else if(!(x == rt->val))
    {
        rt->right = AVL_delete(rt->right,x);
    }
    else if(rt->left != NULL && rt->right != NULL)
    {
        // Two children: take over the in-order successor's value (leftmost
        // node of the right subtree), then delete that successor instead.
        BNode<T> *successor = rt->right;
        while(successor->left != NULL) successor = successor->left;
        rt->val = successor->val;
        rt->right = AVL_delete(rt->right,successor->val);
    }
    else
    {
        // Zero or one child: splice the node out. The surviving child (or
        // NULL) takes its place and already carries a correct height.
        BNode<T> *survivor = (rt->left != NULL) ? rt->left : rt->right;
        delete rt;
        return survivor;
    }

    const int lh = get_height(rt->left);
    const int rh = get_height(rt->right);
    rt->height = max(lh,rh) + 1;

    if(left_side_shrank)
    {
        if(rh - lh == 2)
        {
            BNode<T> *heavy = rt->right;
            rt = (get_height(heavy->right) >= get_height(heavy->left)) ? rotate_RR(rt) : rotate_RL(rt);
        }
    }
    else if(lh - rh == 2)
    {
        BNode<T> *heavy = rt->left;
        rt = (get_height(heavy->left) >= get_height(heavy->right)) ? rotate_LL(rt) : rotate_LR(rt);
    }
    return rt;
}

template<class T>
vector<T> AVLTree<T>::traversal()
{
    // Collects every stored value in in-order (left, node, right) sequence.
    vector<T> out;
    traversal(root,out);
    return out;
}

template<class T>
void AVLTree<T>::traversal(BNode<T> *rt,vector<T> &vec)
{
    // Iterative in-order walk (left subtree, node, right subtree) using an
    // explicit stack of pending ancestors; appends each value to vec.
    vector<BNode<T>*> pending;
    BNode<T> *cur = rt;
    while(cur != NULL || !pending.empty())
    {
        // Descend as far left as possible, remembering the path.
        while(cur != NULL)
        {
            pending.push_back(cur);
            cur = cur->left;
        }
        cur = pending.back();
        pending.pop_back();
        vec.push_back(cur->val);
        cur = cur->right;
    }
}

#endif // AVLTREE_H_INCLUDED

//main.cpp
#include <iostream>
#include <ctime>
#include <vector>
#include <algorithm>
#include "avltree.h"
using namespace std;


int main()
{
    // Build a random permutation of 0..len-1.
    const int len = 102400;
    std::vector<int> a(len);   // owns its storage; no manual delete[] needed
    for(int i = 0; i < len; i++) a[i] = i;
    std::srand(static_cast<unsigned>(std::time(0)));
    // Fisher-Yates shuffle: a[i] is swapped with a[j] for j drawn from [0, i].
    // Given a uniform random source, every permutation is equally likely, up
    // to a tiny modulo bias. rand() may be only 15 bits wide (RAND_MAX == 32767
    // on MSVC), which is narrower than len, so two calls are combined.
    for(int i = len - 1; i > 0; i--)
    {
        unsigned long r = static_cast<unsigned long>(std::rand())
                          * (static_cast<unsigned long>(RAND_MAX) + 1UL)
                          + static_cast<unsigned long>(std::rand());
        int j = static_cast<int>(r % static_cast<unsigned long>(i + 1));
        std::swap(a[i], a[j]);
    }

    AVLTree<int> avl;
    bool ok = true;
    // Insert one value at a time, validating the tree after every insertion.
    for(int i = 0; i < len; i++)
    {
        avl.insert(a[i]);
        if(!avl.isAVL())
        {
            std::cout << "not AVL tree after insert" << std::endl;
            ok = false;
        }
    }
    // isAVL() does not check BST ordering. The in-order walk must yield
    // exactly 0..len-1, which confirms ordering and that no value was lost.
    std::vector<int> vec = avl.traversal();
    bool in_order = static_cast<int>(vec.size()) == len;
    for(int i = 0; in_order && i < len; i++) in_order = (vec[i] == i);
    if(!in_order)
    {
        std::cout << "in-order traversal is not 0..len-1 after insert" << std::endl;
        ok = false;
    }
    std::cout << std::endl;
    // Delete one value at a time, validating the tree after every deletion.
    for(int i = 0; i < len; i++)
    {
        avl.erase(a[i]);
        if(!avl.isAVL())
        {
            std::cout << "not AVL tree after delete" << std::endl;
            ok = false;
        }
    }
    if(!avl.traversal().empty())
    {
        std::cout << "tree not empty after deleting every value" << std::endl;
        ok = false;
    }
    // A non-zero exit status lets scripts/CI detect a failed run.
    return ok ? 0 : 1;
}

posted @ 2020-03-09 14:32  技术流选手  阅读(354)  评论(0)  编辑  收藏  举报