LeetCode 501. Find Mode in Binary Search Tree

Given a binary search tree (BST) with duplicates, find all the mode(s) (the most frequently occurring element) in the given BST.

Assume a BST is defined as follows:

  • The left subtree of a node contains only nodes with keys less than or equal to the node's key.
  • The right subtree of a node contains only nodes with keys greater than or equal to the node's key.
  • Both the left and right subtrees must also be binary search trees.

For example:
Given BST [1,null,2,2],

   1
    \
     2
    /
   2

return [2].
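To run the solutions below locally, you need a TreeNode class and a way to turn LeetCode's level-order notation, such as [1,null,2,2], into a tree. The helper below is a minimal sketch for local testing. `build_tree` is a hypothetical name that is not part of the problem, and Python's `None` stands in for `null`.

```python
from collections import deque


class TreeNode(object):
    """Binary tree node, matching the commented-out definition LeetCode provides."""
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values):
    """Build a tree from a LeetCode-style level-order list (None marks a missing child).

    Hypothetical helper for local testing only.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present) is the current node's left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that is its right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# The example tree: 1 -> right 2 -> left 2
root = build_tree([1, None, 2, 2])
```

Children are assigned in the same breadth-first order in which LeetCode serializes trees, so a queue of the nodes still waiting for children is enough.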

Note: If a tree has more than one mode, you can return them in any order.

Follow up: Could you do that without using any extra space? (Assume that the implicit stack space incurred due to recursion does not count).

Brute-force solution: count every value with a hash map, then return the value(s) with the highest count.

# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

import collections

class Solution(object):
    def findMode(self, root):
        """
        :type root: TreeNode
        :rtype: List[int]
        """
        # In-order DFS that counts how often each value occurs.
        def dfs(node, cnt):
            if not node: return
            dfs(node.left, cnt)
            cnt[node.val] += 1
            dfs(node.right, cnt)
        cnt = collections.defaultdict(int)
        dfs(root, cnt)
        # Keep every value whose count equals the highest count seen so far.
        ans, max_cnt = [], 0
        for k, v in cnt.items():
            if v > max_cnt:
                max_cnt = v
                ans = [k]
            elif v == max_cnt:  # dict keys are unique, so k is never already in ans
                ans.append(k)
        return ans

The last few lines can be replaced with Python's built-in max:

mc = max(cnt.values()) if cnt else 0  # max() raises ValueError on an empty tree
return [n for n, c in cnt.items() if c == mc]
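For reference, here is the brute-force approach assembled into one self-contained function. It swaps `defaultdict(int)` for `collections.Counter` (an equivalent counting dict from the standard library) and uses the `max` shortcut, with a guard for the empty tree, where `max()` on an empty sequence would raise `ValueError`. The function name `find_mode_hash` is illustrative only.

```python
from collections import Counter


class TreeNode(object):
    def __init__(self, x, left=None, right=None):
        self.val = x
        self.left = left
        self.right = right


def find_mode_hash(root):
    """Count every value with a hash map, then keep the values with the top count."""
    cnt = Counter()

    def dfs(node):
        # Any traversal order works here, because counting ignores order.
        if not node:
            return
        cnt[node.val] += 1
        dfs(node.left)
        dfs(node.right)

    dfs(root)
    if not cnt:  # empty tree: max() would raise ValueError
        return []
    mc = max(cnt.values())
    return [n for n, c in cnt.items() if c == mc]


# The example tree [1,null,2,2]
example = TreeNode(1, right=TreeNode(2, left=TreeNode(2)))
print(find_mode_hash(example))  # [2]
```

This version ignores the BST property entirely, so it uses O(n) extra space for the counts and does not meet the follow-up.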

 

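The other approach is the classic tree traversal: during the DFS, pre_node records the previously visited node, and its value is compared with the current node's value. Because an in-order traversal of a BST visits values in non-decreasing order, equal values are adjacent, so tracking the length of the current run is enough and no hash map is needed: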

# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution(object):
    def findMode(self, root):
        """
        :type root: TreeNode
        :rtype: List[int]
        """
        ans, max_cnt = [], 0
        pre_node, pre_cnt = None, 1  # previous in-order node, length of the current run

        def dfs(node):
            nonlocal ans, max_cnt, pre_node, pre_cnt  # nonlocal requires Python 3
            if not node: return
            dfs(node.left)
            if not pre_node:  # first node visited: initialize
                max_cnt = 1
                ans = [node.val]
            else:
                # Equal values are adjacent in in-order, so extend or restart the run.
                if node.val == pre_node.val:
                    pre_cnt += 1
                else:
                    pre_cnt = 1
                if pre_cnt > max_cnt:
                    max_cnt = pre_cnt
                    ans = [node.val]
                elif pre_cnt == max_cnt:
                    ans.append(node.val)
            pre_node = node
            dfs(node.right)

        dfs(root)

        return ans
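The key fact behind the pre_node approach is that an in-order traversal of a BST produces a sorted stream, so each distinct value appears as one consecutive run. The sketch below makes that explicit with a generator and `itertools.groupby`. It is an alternative formulation, not the author's code, and `inorder` and `find_mode_runs` are hypothetical names. Like the pre_node version, it keeps only the current best values rather than a count for every value.

```python
from itertools import groupby


class TreeNode(object):
    def __init__(self, x, left=None, right=None):
        self.val = x
        self.left = left
        self.right = right


def inorder(node):
    """Yield BST values in non-decreasing order (left subtree, node, right subtree)."""
    if node:
        for v in inorder(node.left):
            yield v
        yield node.val
        for v in inorder(node.right):
            yield v


def find_mode_runs(root):
    ans, max_cnt = [], 0
    # groupby splits the sorted stream into runs of equal values.
    for val, run in groupby(inorder(root)):
        run_len = sum(1 for _ in run)  # consume the run without storing it
        if run_len > max_cnt:
            max_cnt, ans = run_len, [val]
        elif run_len == max_cnt:
            ans.append(val)
    return ans


example = TreeNode(1, right=TreeNode(2, left=TreeNode(2)))
print(find_mode_runs(example))  # [2]
```

`groupby` only merges adjacent equal items, which is exactly why it is correct here: it relies on the in-order stream being sorted, just as the pre_node comparison does.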

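A solution that also works under Python 2: instead of nonlocal, it keeps its state on self. Using dummy (sentinel) initial values is actually very convenient here!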

class Solution(object):
    def findMode(self, root):
        """
        :type root: TreeNode
        :rtype: List[int]
        """
        if root is None:
            return []

        self.curVal = root.val - 1  # dummy value: below root.val, so it cannot match the first (largest) value visited
        self.curNum = 0  # dummy value: length of the current run
        self.maxNum = 0
        self.maxVals = []

        # Reverse in-order traversal (right, node, left): values arrive in
        # non-increasing order, so equal values are still adjacent.
        def dfs(root):
            if root is not None:

                dfs(root.right)

                if root.val != self.curVal:
                    self.curNum = 0  # a new run of equal values starts
                self.curNum = self.curNum + 1
                self.curVal = root.val
                if self.curNum == self.maxNum:
                    self.maxVals.append(self.curVal)
                elif self.curNum > self.maxNum:
                    self.maxNum = self.curNum
                    self.maxVals = [self.curVal]

                dfs(root.left)

        dfs(root)
        return self.maxVals

An iterative solution using an explicit stack:

class Solution(object):
    def findMode(self, root):
        """
        :type root: TreeNode
        :rtype: List[int]
        """
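        # res holds (best count so far, list of values with that count)
        stack, node, prev, cnt, res = [], root, None, 0, (0, [])
        while stack or node:
            if node:
                # Walk as far left as possible, saving the path on the stack.
                stack.append(node)
                node = node.left
            else:
                # Visit the next node in in-order (non-decreasing) order.
                node = stack.pop()
                if node.val != prev:
                    cnt = 0  # a new run of equal values starts
                cnt += 1
                if cnt > res[0]:
                    res = (cnt, [node.val])
                elif cnt == res[0]:
                    res[1].append(node.val)
                prev = node.val
                node = node.right
        return res[1]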
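Strictly speaking, none of the versions above uses O(1) extra space: the recursive ones rely on the call stack (which the follow-up explicitly excludes) and the iterative one keeps an explicit stack. If you want to avoid even that, one option, not taken from the original post, is a Morris in-order traversal. It temporarily points each node's in-order predecessor back at the node instead of using a stack, then removes that link on the way back. The sketch below feeds the same run-counting logic from a Morris generator. `morris_inorder` and `find_mode_morris` are hypothetical names, and the returned list is the output, which is usually not counted as extra space.

```python
class TreeNode(object):
    def __init__(self, x, left=None, right=None):
        self.val = x
        self.left = left
        self.right = right


def morris_inorder(root):
    """Yield values in order with O(1) extra space by temporarily threading the tree.

    The tree is fully restored only if the generator is consumed to the end.
    """
    node = root
    while node:
        if node.left is None:
            yield node.val
            node = node.right
        else:
            # The in-order predecessor is the rightmost node of the left subtree.
            pred = node.left
            while pred.right is not None and pred.right is not node:
                pred = pred.right
            if pred.right is None:
                # First arrival: link the predecessor back to node, then go left.
                pred.right = node
                node = node.left
            else:
                # Second arrival: the left subtree is done, so remove the link.
                pred.right = None
                yield node.val
                node = node.right


def find_mode_morris(root):
    ans, max_cnt = [], 0
    prev, cnt = None, 0
    for val in morris_inorder(root):
        cnt = cnt + 1 if val == prev else 1  # extend or restart the current run
        prev = val
        if cnt > max_cnt:
            max_cnt, ans = cnt, [val]
        elif cnt == max_cnt:
            ans.append(val)
    return ans


example = TreeNode(1, right=TreeNode(2, left=TreeNode(2)))
print(find_mode_morris(example))  # [2]
```

The trade-off is that the tree is mutated while the traversal runs, so this is not safe if another thread reads the tree concurrently.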

 

 

posted @ 2018-06-06 22:22  bonelee