python 二叉树

from __future__ import annotations


class TreeNode:
    """A binary-tree node holding a value and optional left/right children.

    Attributes:
        value: The payload stored in this node.
        l_node: The left child, or ``None`` if there is none.
        r_node: The right child, or ``None`` if there is none.
    """

    def __init__(
        self,
        value,
        l_node: TreeNode | None = None,
        r_node: TreeNode | None = None,
    ):
        self.value, self.l_node, self.r_node = value, l_node, r_node

    def get_value(self):
        """Return the value stored in this node."""
        stored = self.value
        return stored

    def set_l_node(self, node: TreeNode | None) -> None:
        """Attach *node* as the left child, replacing any existing one."""
        setattr(self, "l_node", node)

    def set_r_node(self, node: TreeNode | None) -> None:
        """Attach *node* as the right child, replacing any existing one."""
        setattr(self, "r_node", node)


class GenTree:
    """Build an (unbalanced) binary search tree from a sequence of values.

    Values are inserted in the order given. A value smaller than a node's
    value goes into its left subtree; any other value (including duplicates)
    goes into its right subtree.

    Attributes:
        root: Root node of the tree, or ``None`` when *values* is empty.
        values: The values passed to the constructor.
    """

    def __init__(self, values: list) -> None:
        # Stays None if `values` is empty.
        self.root = None
        self.values = values
        self.gen_tree()

    def get_root(self) -> TreeNode | None:
        """Return the root node, or ``None`` if the tree is empty."""
        return self.root

    def gen_tree(self):
        """Insert every value from ``self.values`` into the tree.

        The first value becomes the root. Calling this method again inserts
        the same values a second time; the duplicates go to the right.
        """
        for value in self.values:
            if self.root is None:
                self.root = TreeNode(value)
            else:
                self._gen_tree(self.root, value)

    def _gen_tree(self, node: TreeNode, value):
        """Insert *value* into the subtree rooted at *node*.

        Walks down the tree iteratively rather than recursively. Sorted input
        produces a degenerate (list-shaped) tree whose depth equals the number
        of values, and a recursive walk would exceed Python's recursion limit
        on large inputs.
        """
        while True:
            # Smaller values go left; equal or larger values go right.
            side = "l_node" if value < node.value else "r_node"
            child = getattr(node, side)
            if child is None:
                setattr(node, side, TreeNode(value))
                return
            node = child


def pre_traverse_tree(node: TreeNode | None):
    """Lazily yield the values of the tree rooted at *node* in pre-order.

    The order is the node itself, then its left subtree, then its right
    subtree. Yields nothing when *node* is ``None``.

    Uses an explicit stack instead of nested recursive generators. With
    recursion, a degenerate (list-shaped) tree deeper than the interpreter's
    recursion limit would raise RecursionError.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if current is None:
            continue
        yield current.value
        # Push the right child first so the left subtree is popped (visited) first.
        stack.append(current.r_node)
        stack.append(current.l_node)


def in_traverse_tree(node: TreeNode | None):
    """Lazily yield the values of the tree rooted at *node* in in-order.

    The order is the left subtree, then the node itself, then the right
    subtree, so a binary search tree yields its values in sorted
    (non-decreasing) order. Yields nothing when *node* is ``None``.

    Uses an explicit stack, so very deep (degenerate) trees do not hit the
    interpreter's recursion limit.
    """
    stack = []
    current = node
    while stack or current is not None:
        # Descend as far left as possible, remembering the path back up.
        while current is not None:
            stack.append(current)
            current = current.l_node
        current = stack.pop()
        yield current.value
        # The left subtree and the node are done; continue with the right subtree.
        current = current.r_node


# Demo: build a small BST, then print its pre-order and in-order traversals,
# one list per line.
root = GenTree([2, 3, 1, 4, 5]).get_root()
print(list(pre_traverse_tree(root)), list(in_traverse_tree(root)), sep="\n")

posted @ 2019-10-24 13:46  两只老虎111  阅读(228)  评论(0)  编辑  收藏  举报