(三)用 Go 实现平衡二叉树

本篇我们用 Go 简单地实现平衡二叉查找树(AVL 树)。具体原理可参考大佬的博客:《AVL树(一)之 图文解析 和 C语言的实现》。

1.节点定义

// AVLNode is a single node of an AVL tree.
type AVLNode struct {
	data   int      // key stored in this node
	height int      // height of the subtree rooted here; a new leaf gets 1
	left   *AVLNode // subtree with smaller keys
	right  *AVLNode // subtree with larger keys
}

2.树的遍历

// PreTraverse prints the tree in pre-order (node, left, right). Each node is
// printed as "data:height " so the stored heights can be inspected.
func PreTraverse(p *AVLNode) {
	if p != nil {
		fmt.Printf("%d:%d ", p.data, p.height)
		PreTraverse(p.left)
		PreTraverse(p.right)
	}
}

// InTraverse prints the node values in in-order (left, node, right), which for
// this search tree lists the keys in ascending order.
func InTraverse(p *AVLNode) {
	if p == nil {
		return
	}
	InTraverse(p.left)
	fmt.Printf("%d ", p.data)
	InTraverse(p.right)
}

// PostTraverse prints the node values in post-order (left, right, node).
func PostTraverse(p *AVLNode) {
	if p == nil {
		return
	}
	PostTraverse(p.left)
	PostTraverse(p.right)
	fmt.Printf("%d ", p.data)
}

3.树的旋转

// ll_rotate fixes a left-left imbalance with a single right rotation at k2:
// k2's left child becomes the subtree root and k2 becomes its right child.
// k2 is now the lower node, so its height is recomputed first.
// It returns the new subtree root.
func ll_rotate(k2 *AVLNode) *AVLNode {
	newRoot := k2.left
	k2.left, newRoot.right = newRoot.right, k2

	k2.height = 1 + max(height(k2.left), height(k2.right))
	newRoot.height = 1 + max(height(newRoot.left), k2.height)
	return newRoot
}

// rr_rotate fixes a right-right imbalance with a single left rotation at k1:
// k1's right child becomes the subtree root and k1 becomes its left child.
// k1 is now the lower node, so its height is recomputed first.
// It returns the new subtree root.
func rr_rotate(k1 *AVLNode) *AVLNode {
	newRoot := k1.right
	k1.right, newRoot.left = newRoot.left, k1

	k1.height = 1 + max(height(k1.left), height(k1.right))
	newRoot.height = 1 + max(k1.height, height(newRoot.right))
	return newRoot
}

// lr_rotate fixes a left-right imbalance at k3. It first rotates k3's left
// child to the left (rr_rotate), which turns the case into left-left, and then
// rotates k3 to the right (ll_rotate). It returns the new subtree root.
func lr_rotate(k3 *AVLNode) *AVLNode {
	pivot := k3.left
	k3.left = rr_rotate(pivot)
	return ll_rotate(k3)
}

// rl_rotate fixes a right-left imbalance at k1. It first rotates k1's right
// child to the right (ll_rotate), which turns the case into right-right, and
// then rotates k1 to the left (rr_rotate). It returns the new subtree root.
func rl_rotate(k1 *AVLNode) *AVLNode {
	pivot := k1.right
	k1.right = ll_rotate(pivot)
	return rr_rotate(k1)
}

4.插入节点

// Add inserts data into the AVL tree rooted at p and returns the new root.
// Duplicate values are rejected with a message and leave the tree unchanged.
// For debugging it prints the rotation kind ("ll", "lr", "rr", "rl") on every
// rebalance. Each pre-existing node on the insertion path also prints its value
// and updated height.
func Add(p *AVLNode, data int) *AVLNode {
	if p == nil {
		return &AVLNode{data: data, height: 1}
	}

	switch {
	case data < p.data:
		p.left = Add(p.left, data)
		if height(p.left)-height(p.right) == 2 {
			// The key went into the left subtree. Which side of p.left it
			// landed on decides between a double (LR) and a single (LL) rotation.
			if data > p.left.data {
				fmt.Println("lr")
				p = lr_rotate(p)
			} else {
				fmt.Println("ll")
				p = ll_rotate(p)
			}
		}
	case data > p.data:
		p.right = Add(p.right, data)
		if height(p.right)-height(p.left) == 2 {
			if data > p.right.data {
				fmt.Println("rr")
				p = rr_rotate(p)
			} else {
				fmt.Println("rl")
				p = rl_rotate(p)
			}
		}
	default:
		fmt.Println("Add fail: not allowed same data!")
	}

	p.height = max(height(p.left), height(p.right)) + 1
	fmt.Printf("节点:%d, 高度:%d\n", p.data, p.height)
	return p
}

5.查询节点

// Find searches the tree rooted at p for data. It returns the matching node,
// or nil when data is absent or the tree is empty (p == nil).
func Find(p *AVLNode, data int) *AVLNode {
	// Iterative binary-search descent. A nil p (an empty tree) simply ends the
	// loop, so callers may pass an empty root without panicking.
	for p != nil && p.data != data {
		if data < p.data {
			p = p.left
		} else {
			p = p.right
		}
	}
	return p
}

// maxNode returns the right-most (largest) node of the subtree rooted at p,
// or nil for an empty subtree.
func maxNode(p *AVLNode) *AVLNode {
	if p == nil {
		return nil
	}
	if p.right == nil {
		return p
	}
	return maxNode(p.right)
}

// minNode returns the left-most (smallest) node of the subtree rooted at p,
// or nil for an empty subtree.
func minNode(p *AVLNode) *AVLNode {
	cur := p
	for cur != nil && cur.left != nil {
		cur = cur.left
	}
	return cur
}

6.删除节点

// Delete removes the node holding data from the tree rooted at p and returns
// the new root. If the tree is empty or data is not present, the tree is
// returned unchanged, so callers can always write root = Delete(root, x).
func Delete(p *AVLNode, data int) *AVLNode {
	if p == nil {
		return nil
	}
	node := Find(p, data)
	if node == nil {
		// Nothing to remove: keep the existing tree instead of dropping it.
		return p
	}
	return delete(p, node)
}

// delete removes the node whose key equals node.data from the subtree rooted
// at p. It rebalances on the way back up and returns the new subtree root.
// The search is by key, so node only needs to carry the value to remove.
// A nil p (key not in this subtree) returns nil.
func delete(p, node *AVLNode) *AVLNode {
	if p == nil {
		return nil
	}

	switch {
	case node.data < p.data:
		p.left = delete(p.left, node)
		// The left side shrank, so the right side may now be two levels taller.
		if height(p.right)-height(p.left) == 2 {
			// Use the single rotation whenever the right child's right
			// subtree is at least as tall as its left one. A double
			// rotation in the equal-height case can leave the rotated
			// subtree unbalanced.
			if height(p.right.right) >= height(p.right.left) {
				p = rr_rotate(p)
			} else {
				p = rl_rotate(p)
			}
		}
	case node.data > p.data:
		p.right = delete(p.right, node)
		// Mirror image of the case above: prefer the single rotation on ties.
		if height(p.left)-height(p.right) == 2 {
			if height(p.left.left) >= height(p.left.right) {
				p = ll_rotate(p)
			} else {
				p = lr_rotate(p)
			}
		}
	default:
		if p.left != nil && p.right != nil {
			// Two children: copy in the in-order predecessor (when the left
			// side is taller) or the successor, then remove that node from
			// the taller side. Shrinking the taller side keeps p within balance.
			if height(p.left) > height(p.right) {
				pred := maxNode(p.left)
				p.data = pred.data
				p.left = delete(p.left, pred)
			} else {
				succ := minNode(p.right)
				p.data = succ.data
				p.right = delete(p.right, succ)
			}
		} else if p.left != nil {
			// Only a left child: it takes p's place.
			p = p.left
		} else {
			// Only a right child, or a leaf (p becomes nil).
			p = p.right
		}
	}

	if p != nil {
		p.height = max(height(p.left), height(p.right)) + 1
	}
	return p
}

7.完整代码

package main

import (
    "fmt"
)

// AVLNode is a single node of an AVL tree.
type AVLNode struct {
	data   int      // key stored in this node
	height int      // height of the subtree rooted here; a new leaf gets 1
	left   *AVLNode // subtree with smaller keys
	right  *AVLNode // subtree with larger keys
}

// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}

// height returns the stored height of p. A nil (empty) subtree has height 0.
func height(p *AVLNode) int {
	if p == nil {
		return 0
	}
	return p.height
}

// PreTraverse prints the tree in pre-order (node, left, right). Each node is
// printed as "data:height " so the stored heights can be inspected.
func PreTraverse(p *AVLNode) {
	if p != nil {
		fmt.Printf("%d:%d ", p.data, p.height)
		PreTraverse(p.left)
		PreTraverse(p.right)
	}
}

// InTraverse prints the node values in in-order (left, node, right), which for
// this search tree lists the keys in ascending order.
func InTraverse(p *AVLNode) {
	if p == nil {
		return
	}
	InTraverse(p.left)
	fmt.Printf("%d ", p.data)
	InTraverse(p.right)
}

// PostTraverse prints the node values in post-order (left, right, node).
func PostTraverse(p *AVLNode) {
	if p == nil {
		return
	}
	PostTraverse(p.left)
	PostTraverse(p.right)
	fmt.Printf("%d ", p.data)
}


// ll_rotate fixes a left-left imbalance with a single right rotation at k2:
// k2's left child becomes the subtree root and k2 becomes its right child.
// k2 is now the lower node, so its height is recomputed first.
// It returns the new subtree root.
func ll_rotate(k2 *AVLNode) *AVLNode {
	newRoot := k2.left
	k2.left, newRoot.right = newRoot.right, k2

	k2.height = 1 + max(height(k2.left), height(k2.right))
	newRoot.height = 1 + max(height(newRoot.left), k2.height)
	return newRoot
}

// rr_rotate fixes a right-right imbalance with a single left rotation at k1:
// k1's right child becomes the subtree root and k1 becomes its left child.
// k1 is now the lower node, so its height is recomputed first.
// It returns the new subtree root.
func rr_rotate(k1 *AVLNode) *AVLNode {
	newRoot := k1.right
	k1.right, newRoot.left = newRoot.left, k1

	k1.height = 1 + max(height(k1.left), height(k1.right))
	newRoot.height = 1 + max(k1.height, height(newRoot.right))
	return newRoot
}

// lr_rotate fixes a left-right imbalance at k3. It first rotates k3's left
// child to the left (rr_rotate), which turns the case into left-left, and then
// rotates k3 to the right (ll_rotate). It returns the new subtree root.
func lr_rotate(k3 *AVLNode) *AVLNode {
	pivot := k3.left
	k3.left = rr_rotate(pivot)
	return ll_rotate(k3)
}

// rl_rotate fixes a right-left imbalance at k1. It first rotates k1's right
// child to the right (ll_rotate), which turns the case into right-right, and
// then rotates k1 to the left (rr_rotate). It returns the new subtree root.
func rl_rotate(k1 *AVLNode) *AVLNode {
	pivot := k1.right
	k1.right = ll_rotate(pivot)
	return rr_rotate(k1)
}

// Add inserts data into the AVL tree rooted at p and returns the new root.
// Duplicate values are rejected with a message and leave the tree unchanged.
// For debugging it prints the rotation kind ("ll", "lr", "rr", "rl") on every
// rebalance. Each pre-existing node on the insertion path also prints its value
// and updated height.
func Add(p *AVLNode, data int) *AVLNode {
	if p == nil {
		return &AVLNode{data: data, height: 1}
	}

	switch {
	case data < p.data:
		p.left = Add(p.left, data)
		if height(p.left)-height(p.right) == 2 {
			// The key went into the left subtree. Which side of p.left it
			// landed on decides between a double (LR) and a single (LL) rotation.
			if data > p.left.data {
				fmt.Println("lr")
				p = lr_rotate(p)
			} else {
				fmt.Println("ll")
				p = ll_rotate(p)
			}
		}
	case data > p.data:
		p.right = Add(p.right, data)
		if height(p.right)-height(p.left) == 2 {
			if data > p.right.data {
				fmt.Println("rr")
				p = rr_rotate(p)
			} else {
				fmt.Println("rl")
				p = rl_rotate(p)
			}
		}
	default:
		fmt.Println("Add fail: not allowed same data!")
	}

	p.height = max(height(p.left), height(p.right)) + 1
	fmt.Printf("节点:%d, 高度:%d\n", p.data, p.height)
	return p
}

// Find searches the tree rooted at p for data. It returns the matching node,
// or nil when data is absent or the tree is empty (p == nil).
func Find(p *AVLNode, data int) *AVLNode {
	// Iterative binary-search descent. A nil p (an empty tree) simply ends the
	// loop, so callers may pass an empty root without panicking.
	for p != nil && p.data != data {
		if data < p.data {
			p = p.left
		} else {
			p = p.right
		}
	}
	return p
}

// maxNode returns the right-most (largest) node of the subtree rooted at p,
// or nil for an empty subtree.
func maxNode(p *AVLNode) *AVLNode {
	if p == nil {
		return nil
	}
	if p.right == nil {
		return p
	}
	return maxNode(p.right)
}

// minNode returns the left-most (smallest) node of the subtree rooted at p,
// or nil for an empty subtree.
func minNode(p *AVLNode) *AVLNode {
	cur := p
	for cur != nil && cur.left != nil {
		cur = cur.left
	}
	return cur
}
    
// Delete removes the node holding data from the tree rooted at p and returns
// the new root. If the tree is empty or data is not present, the tree is
// returned unchanged, so callers can always write root = Delete(root, x).
func Delete(p *AVLNode, data int) *AVLNode {
	if p == nil {
		return nil
	}
	node := Find(p, data)
	if node == nil {
		// Nothing to remove: keep the existing tree instead of dropping it.
		return p
	}
	return delete(p, node)
}

// delete removes the node whose key equals node.data from the subtree rooted
// at p. It rebalances on the way back up and returns the new subtree root.
// The search is by key, so node only needs to carry the value to remove.
// A nil p (key not in this subtree) returns nil.
func delete(p, node *AVLNode) *AVLNode {
	if p == nil {
		return nil
	}

	switch {
	case node.data < p.data:
		p.left = delete(p.left, node)
		// The left side shrank, so the right side may now be two levels taller.
		if height(p.right)-height(p.left) == 2 {
			// Use the single rotation whenever the right child's right
			// subtree is at least as tall as its left one. A double
			// rotation in the equal-height case can leave the rotated
			// subtree unbalanced.
			if height(p.right.right) >= height(p.right.left) {
				p = rr_rotate(p)
			} else {
				p = rl_rotate(p)
			}
		}
	case node.data > p.data:
		p.right = delete(p.right, node)
		// Mirror image of the case above: prefer the single rotation on ties.
		if height(p.left)-height(p.right) == 2 {
			if height(p.left.left) >= height(p.left.right) {
				p = ll_rotate(p)
			} else {
				p = lr_rotate(p)
			}
		}
	default:
		if p.left != nil && p.right != nil {
			// Two children: copy in the in-order predecessor (when the left
			// side is taller) or the successor, then remove that node from
			// the taller side. Shrinking the taller side keeps p within balance.
			if height(p.left) > height(p.right) {
				pred := maxNode(p.left)
				p.data = pred.data
				p.left = delete(p.left, pred)
			} else {
				succ := minNode(p.right)
				p.data = succ.data
				p.right = delete(p.right, succ)
			}
		} else if p.left != nil {
			// Only a left child: it takes p's place.
			p = p.left
		} else {
			// Only a right child, or a leaf (p becomes nil).
			p = p.right
		}
	}

	if p != nil {
		p.height = max(height(p.left), height(p.right)) + 1
	}
	return p
}


// main builds a sample AVL tree and prints its three traversals. It then looks
// up a key, deletes a key and prints the tree again. Add logs rotations and
// per-node heights while the tree is being built.
func main() {
	//num := []int{50, 30, 20, 25, 70, 90, 100}
	num := []int{3, 2, 1, 4, 5, 6, 7, 16, 15, 14, 13, 12, 11, 10, 8, 9}

	var root *AVLNode
	for _, v := range num {
		fmt.Printf("插入节点:%d\n", v)
		root = Add(root, v)
	}

	fmt.Println("前序遍历:")
	PreTraverse(root)
	fmt.Println()

	fmt.Println("中序遍历:")
	InTraverse(root)
	fmt.Println()

	fmt.Println("后序遍历:")
	PostTraverse(root)
	fmt.Println()

	// childData renders a child's key, or "nil" when the child is absent.
	// A found node may be a leaf, so its children must not be dereferenced blindly.
	childData := func(c *AVLNode) string {
		if c == nil {
			return "nil"
		}
		return fmt.Sprint(c.data)
	}

	avlnode := Find(root, 60)
	if avlnode != nil {
		fmt.Println("查询结果:")
		fmt.Printf("节点:%d 左子节点:%s 右子节点:%s\n", avlnode.data, childData(avlnode.left), childData(avlnode.right))
	}

	root = Delete(root, 8)
	fmt.Println("删除后前序遍历:")
	PreTraverse(root)
	fmt.Println()

	fmt.Println("删除后中序遍历:")
	InTraverse(root)
	fmt.Println()
}

posted @ 2021-10-29 15:41  qxcheng  阅读(88)  评论(0)  编辑  收藏  举报