二叉搜索树 - 节点计数方法输出不正确
简单来说,从这个函数的定义来看,这个函数是用来判断一个二叉搜索树(BST)中每个节点的类型的。我没有遇到错误,但我觉得我的输出结果不对。我有一个辅助方法,可以递归地遍历这个BST,检查每个节点是零、一个还是两个子节点。如果我输入下面这个BST:
22
/ \
12 30
/ \ / \
8 20 25 40
它返回的是0,0,1,但我觉得这不对,应该返回4,0,3吧?因为22、12和30是两个子节点的节点,所以它们应该是2,而8、20、25和40是没有子节点的叶子节点,所以它们应该是0。希望能得到一些帮助!
这是我的代码:
def node_counts(self):
"""
---------------------------------------------------------
Returns the number of the three types of nodes in a BST.
Use: zero, one, two = bst.node_counts()
-------------------------------------------------------
Returns:
zero - number of nodes with zero children (int)
one - number of nodes with one child (int)
two - number of nodes with two children (int)
----------------------------------------------------------
"""
zero, one, two = self.node_counts_aux(self._root)
return zero, one, two
return
def node_counts_aux(self, node):
zero = 0
one = 0
two = 0
if node is None:
return zero, one, two
else:
self.node_counts_aux(node._left)
print(node._value)
self.node_counts_aux(node._right)
if node._left is not None and node._right is not None:
two += 1
elif (node._left is not None and node._right is None) or (node._left is None and node._right is not None):
one += 1
else:
zero += 1
return zero, one, two
我现在正在进行中序遍历,我期待的结果是4,0,3,而不是0,0,1。
1 个回答
1
递归中一个常见的错误是忘记使用返回值。你需要把这些返回值传递回去,才能在最上层“算数”:
def node_counts_aux(self, node):
zero = 0
one = 0
two = 0
if node is None:
return zero, one, two
z, o, t = self.node_counts_aux(node._left)
zero += z
one += o
two += t
z, o, t = self.node_counts_aux(node._right)
zero += z
one += o
two += t
if node._left and node._right:
two += 1
elif node._left or node._right:
one += 1
else:
zero += 1
return zero, one, two
不过,我一般更喜欢用内部函数,而不是辅助函数,并且用列表来代替不同的变量:
def node_counts(self):
counts = [0, 0, 0]
def traverse(node):
if not node:
return
traverse(node._left)
traverse(node._right)
if node._left and node._right:
counts[2] += 1
elif node._left or node._right:
counts[1] += 1
else:
counts[0] += 1
traverse(self._root)
return tuple(counts)
这样更简洁;需要传递的数据更少。可以安全地假设节点的值是有效的。
另外,遍历的顺序其实并不重要。
这是一个可以在你的树上运行的例子:
from collections import namedtuple
def count_node_types(root):
counts = [0, 0, 0]
def traverse(node):
if not node:
return
traverse(node.left)
traverse(node.right)
if node.left and node.right:
counts[2] += 1
elif node.left or node.right:
counts[1] += 1
else:
counts[0] += 1
traverse(root)
return tuple(counts)
if __name__ == "__main__":
Node = namedtuple("Node", "val left right", defaults=[None] * 3)
r""" 22
/ \
12 30
/ \ / \
8 20 25 40 """
root = Node(22, Node(12, Node(8), Node(20)), Node(30, Node(25), Node(40)))
print(count_node_types(root)) # => (4, 0, 3)
... 但最好选择一个测试案例,确保有一个孩子的节点能被正确计数,所以我会去掉 Node(40)
,这样可以确保返回 (3, 1, 2)
。