一道关于二叉树的字节面试题的思考
技术人的精神,就是追根究底,把一个事情彻底弄清楚吧!
题目
众所周知,字节在一二面的末尾,会随机抽一道算法题,当场写代码。我抽到的题目如下:
二叉树根节点到叶子节点的所有路径和。给定一个仅包含数字 0−9 的二叉树,每一条从根节点到叶子节点的路径都可以用一个数字表示。例如根节点到叶子节点的一条路径是 1→2→3,那么这条路径就用 123 来代替。找出根节点到叶子节点的所有路径表示的数字之和。
例如:这棵二叉树一共有两条路径,根节点到左叶子节点的路径 12 代替,根节点到右叶子节点的路径用 13 代替。所以答案为12+13=25 。
递归解法
看到这个题目,首先想到的,其实就是找到这个二叉树的从根节点到叶子节点的所有路径。而要找到所有路径,第一想到的肯定是递归。通过左子树的递归拿到的路径、右子树的递归拿到的路径,以及根节点,得出最终的所有路径。
算法如下:
STEP1:如果已经是叶子节点,那么构造一条路径列表,该路径只有一个元素即叶子节点的值,然后返回【退出条件】。
STEP2: 递归找到左子树到叶子节点的所有路径列表。对于每条路径,将根节点加入,从而得到新的结果路径,并加入;
STEP3:递归找到右子树到叶子节点的所有路径列表。对于每条路径,将根节点加入,从而得到新的结果路径,并加入;
STEP4: 将左右子树的所有路径合并成最终的路径列表【组合子问题的解】。
有两点说明下:
- 由于节点的值只有 0-9,因此,可以直接用字符串来表示路径。如果用 List[Integer] ,更灵活,不过会变成列表的列表,处理起来会有点绕。
- 构建路径时,使用的是 StringBuilder 的 append 方法,而不是 insert 方法,因此构造的路径是逆序的。主要考虑到 insert 方法会导致数组频繁移动,效率低。具体可以看 StringBuilder 实现。
递归代码如下:
public List<Path> findAllPaths(TreeNode root) {
List<Path> le = new ArrayList<>();
List<Path> ri = new ArrayList<>();
if (root != null) {
if (root.left == null && root.right == null) {
List<Path> single = new ArrayList<>();
single.add(new Path(root.val));
return single;
}
if (root.left != null) {
le = findAllPaths(root.left);
for (Path p: le) {
p.append(root.val);
}
}
if (root.right != null) {
ri = findAllPaths(root.right);
for (Path p: ri) {
p.append(root.val);
}
}
}
List<Path> paths = new ArrayList<>();
paths.addAll(le);
paths.addAll(ri);
return paths;
}
class Path {
StringBuilder s = new StringBuilder();
public Path() { }
public Path(Integer i) {
s.append(i);
}
public Path(List list) {
list.forEach( e-> {
s.append(e);
});
}
public Path(String str) { this.s = new StringBuilder(str); }
public Long getValue() {
return Long.parseLong(s.reverse().toString());
}
public StringBuilder append(Integer i) {
return s.append(i);
}
public String toString() {
return s.reverse().toString();
}
}
class TreeNode {
int val;
TreeNode left;
TreeNode right;
TreeNode(int x) { val = x; }
public int height() {
if (left == null && right == null) {
return 1;
}
int leftHeight = 0;
int rightHeight = 0;
if (left != null) {
leftHeight = left.height();
}
if (right != null) {
rightHeight = right.height();
}
return 1 + max(leftHeight, rightHeight);
}
}
关键点
实际上,我在面试当场没有做出来,但在面试后的十分钟,我就把代码写出来了。可能在面试的时候有点紧张,有个地方一直卡住了。
类似二叉树、动态规划的问题,由于有多条分支,从思维上来说,不像处理数组、链表那样是一种线性思维,而是需要一种非线性思维,因此,多做类似的题目,对思维的锻炼是很有益的,—— 能够帮助人摆脱固有的线性思维。
一般来说,算法问题,通常可以分为两步:1. 划分子问题; 2. 将子问题的解组合成原问题的解。 划分子问题,相对容易一点,但如果划分不合理,就难以想清楚如何去组合解。我一开始就想到了要用子树的解与根节点来组合,但是一直纠结在对求出单条路径的思考上,而不是把所有路径作为子问题的解。这样,我就难以想到如何去组合得到最终解。但面试结束之后,我脑子里闪过左子树的所有路径列表,顿时就明白如何组合了。因此,有时,把“所有”作为子问题的解,再跟上层节点组合,反而能容易地得到原问题的解。此外,递归要特别注意退出条件。
推荐可以多做二叉树、动态规划的题目,能够很好地锻炼划分子问题、组合子问题的解来求解的技能。
非递归算法
实现递归解法,只是一个开始。递归算法很简洁,但执行效率很低,而且容易栈溢出。如果一个足够大的二叉树,就能让递归代码无法执行下去。因此,需要寻求非递归实现。
非递归实现,往往需要借助于栈。我们需要模拟一下如何用栈来访问二叉树。如下图所示:
可以先找找规则,往往规则就是代码的路径。
- 每次走到一个节点,先将节点值入栈;
- 走到叶子节点时,说明已经走到路径的尽头,可以记录下这条路径。
第一版非递归实现如下。用一个栈来存储二叉树的访问节点。如果是叶子节点,就记录路径,然后将叶子节点出栈,继续访问。
public List<Path> findAllPathsNonRecDeadLoop(TreeNode root) {
List<Path> allPaths = new ArrayList<>();
Stack<Integer> s = new DyStack<Integer>();
TreeNode p = root;
while(p != null) {
s.push(p.val);
if (p.left == null && p.right == null) {
allPaths.add(new Path(s.unmodifiedList()));
s.pop();
if (s.isEmpty()) {
break;
}
}
if (p.left != null) {
p = p.left;
}
else if (p.right != null) {
p = p.right;
}
}
return allPaths;
}
不过,这个代码实现会陷入死循环。为什么呢?因为它会无止境重复进入左子树,而且回溯的时候,也没法找到父节点。
回溯
为了解决死循环的问题,我们需要加一些支持:进入某个节点时,必须记下该节点的父节点,以及该父节点是否访问过左右子树。这个信息用 TraceNode 来表示。由于始终需要回溯,因此,TraceNode 必须放在栈中,在适当的时候弹出,就像保存现场一样。当遍历的时候,需要记录已经访问的节点,不重复访问,也需要避免将中间节点重复压栈。
重新理一下。对于当前节点,有四种情形需要考虑:
- 当前节点是叶子节点。记录路径、出栈 treeData, 出栈 traceNode ,回溯到父节点;
- 当前节点不是叶子节点,有左子树,则需要记录该节点指针及左子树已访问,并进入左子树;
- 当前节点不是叶子节点,有右子树,则需要记录该节点指针及右子树已访问,并进入右子树;
- 当前节点不是叶子节点,有左右子树且均已访问,出栈 treeData, 出栈 traceNode ,回溯到父节点。
第二版的递归实现如下:
public List<Path> findAllPathsNonRec(TreeNode root) {
List<Path> allPaths = new ArrayList<>();
Stack<Integer> treeData = new DyStack<>();
Stack<TraceNode> trace = new DyStack<>();
TreeNode p = root;
TraceNode traceNode = TraceNode.getNoAccessedNode(p);
while(p != null) {
if (p.left == null && p.right == null) {
// 叶子节点的情形,需要记录路径,并回溯到父节点
treeData.push(p.val);
allPaths.add(new ListPath(treeData.unmodifiedList()));
treeData.pop();
if (treeData.isEmpty()) {
break;
}
traceNode = trace.pop();
p = traceNode.getParent();
continue;
}
else if (traceNode.needAccessLeft()) {
// 需要访问左子树的情形
treeData.push(p.val);
trace.push(TraceNode.getLeftAccessedNode(p));
p = p.left;
}
else if (traceNode.needAccessRight()) {
// 需要访问右子树的情形
if (traceNode.hasNoLeft()) {
treeData.push(p.val);
}
if (!traceNode.hasAccessedLeft()) {
// 访问左节点时已经入栈过,这里不重复入栈
treeData.push(p.val);
}
trace.push(TraceNode.getRightAccessedNode(p));
p = p.right;
if (p.left != null) {
traceNode = TraceNode.getNoAccessedNode(p);
}
else if (p.right != null) {
traceNode = TraceNode.getLeftAccessedNode(p);
}
}
else if (traceNode.hasAllAccessed()) {
// 左右子树都已经访问了,需要回溯到父节点
if (trace.isEmpty()) {
break;
}
treeData.pop();
traceNode = trace.pop();
p = traceNode.getParent();
}
}
return allPaths;
}
class TraceNode {
private TreeNode parent;
private int accessed; // 0 均未访问 1 已访问左 2 已访问右
public TraceNode(TreeNode parent, int accessed) {
this.parent = parent;
this.accessed = accessed;
}
public static TraceNode getNoAccessedNode(TreeNode parent) {
return new TraceNode(parent, 0);
}
public static TraceNode getLeftAccessedNode(TreeNode parent) {
return new TraceNode(parent, 1);
}
public static TraceNode getRightAccessedNode(TreeNode parent) {
return new TraceNode(parent, 2);
}
public boolean needAccessLeft() {
return parent.left != null && accessed == 0;
}
public boolean needAccessRight() {
return parent.right != null && accessed < 2;
}
public boolean hasAccessedLeft() {
return parent.left == null || (parent.left != null && accessed == 1);
}
public boolean hasNoLeft() {
return parent.left == null;
}
public boolean hasAllAccessed() {
if (parent.left != null && parent.right == null && accessed == 1) {
return true;
}
if (parent.right != null && accessed == 2) {
return true;
}
return false;
}
public TreeNode getParent() {
return parent;
}
public int getAccessed() {
return accessed;
}
}
关于是否已访问左右子树的判断都隐藏在 TraceNode 里,findAllPathsNonRec 方法不感知这个。后续如果觉得用 int 来表示 accessed 空间效率不高,可以内部重构,对 findAllPathsNonRec 无影响。这就是封装的益处。
测试
递归代码和非递归代码都是容易有 BUG 的,需要仔细测试下。测试用例通常至少要包括:
- C1: 单个根节点树;
- C2: 单个根节点 + 左节点;
- C3: 单个根节点 + 右节点;
- C4: 单个根节点 + 左右节点;
- C5: 普通的二叉树,左右随机;
- 复杂的二叉树,非常大。
如何构造复杂的二叉树呢?可以采用构造法。基于简单的 C2,C3,C4,将一棵树的根节点连接到另一棵树的左叶子节点或右叶子节点上。复杂结构总是由简单结构来组合而成。
测试代码如下。用 TreeBuilder 注解来表示构造的二叉树,从而能够批量拿到这些方法构造的树,进行测试。
public static void main(String[] args) {
TreePathSum treePathSum = new TreePathSum();
Method[] methods = treePathSum.getClass().getDeclaredMethods();
for (Method m: methods) {
if (m.isAnnotationPresent(TreeBuilder.class)) {
try {
TreeNode t = (TreeNode) m.invoke(treePathSum, null);
System.out.println("height: " + t.height());
treePathSum.test2(t);
} catch (Exception ex) {
System.err.println(ex.getMessage());
}
}
}
}
public void test(TreeNode root) {
System.out.println("Rec Implementation");
List<Path> paths = findAllPaths(root);
Long sum = paths.stream().collect(Collectors.summarizingLong(Path::getValue)).getSum();
System.out.println(paths);
System.out.println(sum);
System.out.println("Non Rec Implementation");
List<Path> paths2 = findAllPathsNonRec(root);
Long sum2 = paths2.stream().collect(Collectors.summarizingLong(Path::getValue)).getSum();
System.out.println(paths2);
System.out.println(sum2);
assert sum == sum2;
}
public void test2(TreeNode root) {
System.out.println("Rec Implementation");
List<Path> paths = findAllPaths(root);
System.out.println(paths);
System.out.println("Non Rec Implementation");
List<Path> paths2 = findAllPathsNonRec(root);
System.out.println(paths2);
assert paths.size() == paths2.size();
for (int i=0; i < paths.size(); i++) {
assert paths.get(i).toString().equals(paths2.get(i).toString());
}
}
@TreeBuilder
public TreeNode buildTreeOnlyRoot() {
TreeNode tree = new TreeNode(9);
return tree;
}
@TreeBuilder
public TreeNode buildTreeWithL() {
return buildTreeWithL(5, 1);
}
public TreeNode buildTreeWithL(int rootVal, int leftVal) {
TreeNode tree = new TreeNode(rootVal);
TreeNode left = new TreeNode(leftVal);
tree.left = left;
return tree;
}
@TreeBuilder
public TreeNode buildTreeWithR() {
return buildTreeWithR(5,2);
}
public TreeNode buildTreeWithR(int rootVal, int rightVal) {
TreeNode tree = new TreeNode(rootVal);
TreeNode right = new TreeNode(rightVal);
tree.right = right;
return tree;
}
@TreeBuilder
public TreeNode buildTreeWithLR() {
return buildTreeWithLR(5,1,2);
}
public TreeNode buildTreeWithLR(int rootVal, int leftVal, int rightVal) {
TreeNode tree = new TreeNode(rootVal);
TreeNode left = new TreeNode(leftVal);
TreeNode right = new TreeNode(rightVal);
tree.right = right;
tree.left = left;
return tree;
}
Random rand = new Random(System.currentTimeMillis());
@TreeBuilder
public TreeNode buildTreeWithMore() {
TreeNode tree = new TreeNode(5);
TreeNode left = new TreeNode(1);
TreeNode right = new TreeNode(2);
TreeNode left2 = new TreeNode(3);
TreeNode right2 = new TreeNode(4);
tree.right = right;
tree.left = left;
left.left = left2;
left.right = right2;
return tree;
}
@TreeBuilder
public TreeNode buildTreeWithMore2() {
TreeNode tree = new TreeNode(5);
TreeNode left = new TreeNode(1);
TreeNode right = new TreeNode(2);
TreeNode left2 = new TreeNode(3);
TreeNode right2 = new TreeNode(4);
tree.right = right;
tree.left = left;
right.left = left2;
right.right = right2;
return tree;
}
public TreeNode treeWithRandom() {
int c = rand.nextInt(3);
switch (c) {
case 0: return buildTreeWithL(rand.nextInt(9), rand.nextInt(9));
case 1: return buildTreeWithR(rand.nextInt(9), rand.nextInt(9));
case 2: return buildTreeWithLR(rand.nextInt(9), rand.nextInt(9), rand.nextInt(9));
default: return buildTreeOnlyRoot();
}
}
public TreeNode linkRandom(TreeNode t1, TreeNode t2) {
if (t2.left == null) {
t2.left = t1;
}
else if (t2.right == null) {
t2.right = t1;
}
else {
int c = rand.nextInt(4);
switch (c) {
case 0: t2.left.left = t1;
case 1: t2.left.right = t1;
case 2: t2.right.left = t1;
case 3: t2.right.right = t1;
default: t2.left.left = t1;
}
}
return t2;
}
@TreeBuilder
public TreeNode buildTreeWithRandom() {
TreeNode root = treeWithRandom();
int i = 12;
while (i > 0) {
TreeNode t = treeWithRandom();
root = linkRandom(root, t);
i--;
}
return root;
}
经测试,发现第二版非递归程序在某种情况下,还是有 BUG 。这说明某些基本情形还是没覆盖到。用如下测试用例调试,发现就有问题:
@TreeBuilder
public TreeNode buildTreeWithMore4() {
TreeNode tree = new TreeNode(5);
TreeNode left = new TreeNode(1);
TreeNode right = new TreeNode(2);
TreeNode left2 = new TreeNode(3);
TreeNode right2 = new TreeNode(4);
TreeNode right3 = new TreeNode(6);
tree.right = right;
tree.left = left;
left.right = right3;
right.right = left2;
left2.right = right2;
return tree;
}
回溯再思考
问题在哪里?当初次进入没有左子树的右子树时,会有问题。这说明,我还没有真正弄明白整个回溯过程。重新再理一下回溯过程:
- 有一个用来指向当前访问节点的指针 p ;
- 有一个用来存储已访问节点值的栈 treeData;
- 有一个用来回溯的存储最近一次访问的节点信息的栈 trace ;
- 有一个用来指明往哪个方向走的 traceNode 。
问题在于我没想清楚 traceNode 到底是什么含义。 traceNode 的 parent 和 accessed 到底该存放什么。实际上,traceNode 和 p 是配套使用的。p 是当前进入的节点的指针,而 traceNode 用来指明进入 p 之后,该往哪里走。 traceNode 的来源应该有两个:
- 第一次进入 p 时,这时候,左右子树都没有访问过,parent 应该与 p 相同,而 accessed 总是初始化为 0 ;
- 访问了 p 的左子树或右子树,回溯进入 p 时,这时候 parent 应该是 p 的父节点,从 trace 里拿到。
第二版非递归程序正是没有考虑到第一次进入 p 时的情况。 如下代码所示。当 p 进入左子树时,需要将最近一次的父节点信息入栈 trace ,同时需要将 traceNode 设置为初始进入 p 时的情形。进入右子树类似。这一点正是第二版非递归程序没有想清楚的地方。
trace.push(TraceNode.getLeftAccessedNode(p));
p = p.left;
traceNode = TraceNode.getNoAccessedNode(p);
我们做一些修改,得到了第三版非递归程序。经测试是 OK 的。
public List<Path> findAllPathsNonRec(TreeNode root) {
List<Path> allPaths = new ArrayList<>();
Stack<Integer> treeData = new DyStack<>();
Stack<TraceNode> trace = new DyStack<>();
TreeNode p = root;
TraceNode traceNode = TraceNode.getNoAccessedNode(p);
while(p != null) {
if (p.left == null && p.right == null) {
// 叶子节点的情形,需要记录路径,并回溯到父节点
treeData.push(p.val);
allPaths.add(new ListPath(treeData.unmodifiedList()));
treeData.pop();
if (treeData.isEmpty()) {
break;
}
traceNode = trace.pop();
p = traceNode.getParent();
continue;
}
else if (traceNode.needAccessLeft()) {
// 需要访问左子树的情形
treeData.push(p.val);
trace.push(TraceNode.getLeftAccessedNode(p));
p = p.left;
traceNode = TraceNode.getNoAccessedNode(p);
}
else if (traceNode.needAccessRight()) {
// 需要访问右子树的情形
if (traceNode.hasNoLeft()) {
treeData.push(p.val);
}
if (!traceNode.hasAccessedLeft()) {
// 访问左节点时已经入栈过,这里不重复入栈
treeData.push(p.val);
}
trace.push(TraceNode.getRightAccessedNode(p));
p = p.right;
traceNode = TraceNode.getNoAccessedNode(p);
}
else if (traceNode.hasAllAccessed()) {
// 左右子树都已经访问了,需要回溯到父节点
if (trace.isEmpty()) {
break;
}
treeData.pop();
traceNode = trace.pop();
p = traceNode.getParent();
}
}
return allPaths;
}
优化
扩展性
由于题目中所给的节点值为 0-9, 因此,前面取巧用了字符串来表示路径。如果节点值不为 0-9 呢?如果依然要用字符串表示,则需要分隔符。现在,我们用列表来表示路径。封装的好处,就在于可以替换实现,而尽量少地改变客户端代码(在这里是findAllPaths 和 findAllPathsNonRec 方法)。
这里,Path 类改成接口,原来的 Path 类改成 StringPath ,然后用 StringPath 替换 Path 。 将原来 StringPath 用到的方法,定义成接口方法。只用到了 append 和 getValue 方法。不过,构造器方法参数也要兼容。这样,只要把原来的 StringPath 改成 ListPath ,其它基本不用动,就可以运行通过。
interface Path {
void append(Integer i);
Long getValue();
}
class StringPath implements Path { // code as before }
class ListPath implements Path {
List<Integer> path = new ArrayList<>();
public ListPath(int i) {
this.path.add(i);
}
public ListPath(List list) {
this.path.addAll(list);
}
@Override
public void append(Integer i) {
path.add(i);
}
@Override
public Long getValue() {
StringBuilder s = new StringBuilder();
path.forEach( e-> {
s.append(e);
});
return Long.parseLong(s.reverse().toString());
}
public String toString() {
return StringUtils.join(path.toArray(), "");
}
}
小结
花了一天弄懂二叉树回溯的玩法。技术人的精神,就是追根究底,把一个事情彻底弄清楚吧!
在本文中,我们通过一个二叉树的路径寻找面试题,讨论了递归和非递归解法,探讨了非递归过程中遇到的问题,模拟了二叉树的回溯,对于理解二叉树的访问是很有益的。而对于回溯算法的理解,锻炼了非线性思维。此外,当程序有 BUG 时,往往是某个方面没想得足够明白导致。坚持思考,清晰定义,就能向正确再迈进一步。
不看答案,自己弄明白一个问题,收获大大的!
本文完整源代码见:ALLIN 工程里:zzz.study.datastructure.tree.TreePathSum