二叉树的序列化与反序列化
二叉树的序列化与反序列化
字节面试
在我的博客树的遍历中介绍了利用递归建树的算法解决如何从树的先序+中序序列恢复原树的结构。考虑到当时还是debug了很久,这次面试写代码时我转化了思路。面试时在牛客网的平台上没有写出反序列化的完整代码,向面试官解释了代码思想。今天整理了一番。
序列化
将二叉树看作类似堆的完全二叉树的结构,这样从根节点自顶向下,假设code[i]代表节点i的编号,我们规定
\[左孩子:code[lson] = code[i]*2 \\
右孩子:code[rson] = code[i]*2 + 1
\]
通过任意一种树的遍历方式,我们得到每个节点编号和节点值,于是得到整棵树的序列化。
反序列化
有了全部的节点位置信息和节点值,如何重新建树呢?
面试时我写的算法是自下而上的建树方式,原因在于除了根节点每个节点都会有父亲,我们只需要把节点编号/2即可找到父亲,然后判断该节点是左儿子/右儿子,然后向父亲节点添加该节点。建树的过程相当于把孤立的点不断合并形成一棵树。面试时一下子没想清楚如何一次遍历全部节点完成操作,代码未能完全实现。
换一种自顶向下的思路会更容易实现。
建树相当于加上全部的父子关系的边。对全部节点编号从小到大排序后,我们依次加边即可。
判断第i个节点是否有孩子,如果有就加上相应的边(父亲的左右儿子指向该节点)。完成全部加边操作后,返回根节点,即是整棵树的引用。
class Tree:
def __init__(self, val):
self.val = val
self.lson = None
self.rson = None
# 判断两棵树是否等同
def equals(self, tree):
return self.__repr__()==tree.__repr__()
# 打印整棵树
def __repr__(self, dep=0):
rep = '\t'*dep + str(self.val) + '\n'
if self.lson!=None:
rep += self.lson.__repr__(dep+1)
if self.rson!=None:
rep += self.rson.__repr__(dep+1)
return rep
def serialize(self, root=1):
"""
二叉树的序列化
返回节点编号 节点值
"""
codes = [root]
vals = [self.val]
if self.lson:
tmpCodes, tmpVals = self.lson.serialize(root*2)
codes += tmpCodes
vals += tmpVals
if self.rson:
tmpCodes, tmpVals = self.rson.serialize(root*2+1)
codes += tmpCodes
vals += tmpVals
return codes, vals
def deserialize(codes, vals):
"""
二叉树的反序列化
根据节点编号 节点值
重建原树
"""
treeNode = dict(zip(codes, map(Tree, vals)))
codes.sort()
# 从根节点自顶向下加边
for rt in codes:
lson = rt*2
rson = rt*2 + 1
if lson in treeNode:
treeNode[rt].lson = treeNode[lson]
if rson in treeNode:
treeNode[rt].rson = treeNode[rson]
return treeNode[1]
测试
我们要保证 deserialize(serialize(T)) == T
测试代码:
def test():
"""
Tree t:
1
/ \
2 5
/ \ / \
8 6 7 3
\
100
/
20
"""
t = Tree(1)
t.lson = Tree(2)
t.rson = Tree(5)
t.lson.lson = Tree(8)
t.lson.rson = Tree(6)
t.rson.lson = Tree(7)
t.rson.rson = Tree(3)
t.lson.rson.rson = Tree(100)
t.lson.rson.rson.lson = Tree(20)
# print(t)
# 序列化
series = t.serialize()
# codes, vals = series
# print(codes)
# print(vals)
# 反序列化
tt = Tree.deserialize(*series)
print('before:')
print(t)
print('after:')
print(tt)
print('是否相同:', tt.equals(t))
if __name__=='__main__':
test()
测试结果
before:
1
2
8
6
100
20
5
7
3
after:
1
2
8
6
100
20
5
7
3
是否相同: True
(完)