Given a binary tree, collect its nodes as if you were repeatedly doing the following: collect and remove all leaves, then repeat until the tree is empty.

Example:
Given the binary tree:

          1
         / \
        2   3
       / \     
      4   5    

 

Returns [4, 5, 3], [2], [1].

Explanation:

1. Removing the leaves [4, 5, 3] would result in this tree:

          1
         / 
        2          

 

2. Now removing the leaf [2] would result in this tree:

          1          

 

3. Now removing the leaf [1] would result in the empty tree:

          []         

 

 

Returns [4, 5, 3], [2], [1].

 

 

 /// Groups the values of every node in the tree by leaf-removal round.
 ///
 /// Starts from an empty result and lets removeLeaves() fill it: entry 0
 /// receives the values of the leaves, entry 1 the values of the nodes that
 /// sit one step above their deepest leaf, and so on up to the root.
 ///
 /// @param root Root of the tree; a null root yields an empty result.
 /// @return One inner vector per round, first round first.
 vector<vector<int>> findLeaves(TreeNode* root) {
        vector<vector<int>> rounds;
        // The height returned by the helper is not needed at the top level.
        removeLeaves(rounds, root);
        return rounds;
    }
    
    /// Appends every value in the subtree at `root` to `ret`, bucketed by height.
    ///
    /// A null subtree has height 0; any other node has height one more than the
    /// taller of its two subtrees, so leaves have height 1. A node of height h
    /// is appended to `ret[h - 1]`, and `ret` is grown when that bucket does not
    /// exist yet. Children are visited before their parent, left before right.
    ///
    /// Despite the name, the tree itself is never modified: grouping by height
    /// gives the same buckets as repeatedly stripping the leaves would.
    ///
    /// @param ret  Buckets to append to; existing contents are kept.
    /// @param root Subtree to process; may be null.
    /// @return Height of the subtree, 0 when `root` is null.
    int removeLeaves(vector<vector<int>> & ret, TreeNode* root) {
        if (root == nullptr) return 0;

        const int leftHeight = removeLeaves(ret, root->left);
        const int rightHeight = removeLeaves(ret, root->right);
        const int height = max(leftHeight, rightHeight) + 1;

        // height is at least 1 here, so converting it to an unsigned size is
        // lossless and avoids comparing a signed int with ret.size().
        const size_t needed = static_cast<size_t>(height);
        if (ret.size() < needed) ret.resize(needed);
        ret[needed - 1].push_back(root->val);
        return height;
    }