在Python中对树的所有元素应用函数

-1 投票
1 回答
58 浏览
提问于 2025-04-14 16:36

这里有一种像树一样的数据结构,它的每个节点里包含一些简单的值(比如整数)、列表和字典,还有np.array(这是NumPy库中的一种数据结构,用于处理数组)。我需要写一个函数(用Python语言),这个函数可以对树的较低层级进行聚合操作,也就是说,把这些层级的值合并起来,然后用这个合并的结果替换掉原来的值。最后,这个函数会返回一个新的数据结构。

举个例子。

Initial data structure: {“a”: [1, [2, 3, 4], [5, 6, 7]], “b”: [{“c”:8, “d”:9}, {“ e”:3, “f”:4}, 8]}
Aggregation function:   sum

First use:  {“a”: [1, 9, 18]], “b”: [17, 7, 8]}
Second use: {“a”: 28, “b”: 32}
Third use:  60
Fourth use: 60

1 个回答

-1

sum 是一个特别简单的例子,因为你可以以任何顺序进行加法,结果总是一样的。

举个例子,你可以先把树形结构压缩成一个线性集合,使用深度优先搜索或广度优先搜索,然后再把所有的数字加起来,结果还是一样的。

但是如果换成其他的聚合函数,数据的分组方式可能就会影响结果。

我建议使用两个不同的函数,collapse_then_aggregateaggregate_recursively,这两个函数对于像 sum 这样的简单聚合函数会返回相同的结果,但对于更复杂的聚合函数可能会返回不同的结果。

注意 depth_first_searchaggregate_recursively 是如何使用相同的逻辑来递归遍历树的,它们会对每个可迭代对象的元素进行递归调用。字典会单独处理,因为我们关注的是值而不是键,所以用 if isinstance(tree, dict) 来判断。可迭代对象会使用 try/except 来处理。字符串也是可迭代的,但我们通常不想逐个字符地遍历它们,所以我写了一个特殊的情况 if isinstance(tree, (str, bytes)),这样字符串就不会被当作可迭代对象处理。

def depth_first_search(tree):
    if isinstance(tree, dict):
        for subtree in tree.values():
            yield from depth_first_search(subtree)
    elif isinstance(tree, (str, bytes)):
        yield tree
    else:
        try:
            tree_iter = iter(tree)
        except TypeError:
            yield tree
        else:
            for subtree in tree_iter:
                yield from depth_first_search(subtree)

def collapse_then_aggregate(f, tree):
    return f(depth_first_search(tree))

def aggregate_recursively(f, tree):
    if isinstance(tree, dict):
        return f(aggregate_recursively(f, subtree) for subtree in tree.values())
    elif isinstance(tree, (str, bytes)):
        return tree
    else:
        try:
            tree_iter = iter(tree)
        except TypeError:
            return tree
        return f(aggregate_recursively(f, subtree) for subtree in tree_iter)

应用示例:

from math import prod
from statistics import mean, geometric_mean

tree = {'a': [1, [2, 3, 4], [5, 6, 7, 8]], 'b': [{'c':8, 'd':9}, {'e':3, 'f':4}, 8]}

for f in (sum, prod, mean, geometric_mean, list):
    for fold in (collapse_then_aggregate, aggregate_recursively):
        result = fold(f, tree)
        print(f'{f.__name__:4.4}  {fold.__name__:23}  {result}')

结果:

sum   collapse_then_aggregate  68
sum   aggregate_recursively    68
prod  collapse_then_aggregate  278691840
prod  aggregate_recursively    278691840
mean  collapse_then_aggregate  5.230769230769231
mean  aggregate_recursively    5.083333333333334
geom  collapse_then_aggregate  4.462980019474007
geom  aggregate_recursively    4.03915728944794
list  collapse_then_aggregate  [1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 3, 4, 8]
list  aggregate_recursively    [[1, [2, 3, 4], [5, 6, 7, 8]], [[8, 9], [3, 4], 8]]

注意 sumprod 在我们的两种聚合方法中给出的结果是相同的,但 mean, geometric_meanlist 的结果则不同。

撰写回答