找到根据函数加权的列表或集合中的所有最小元素

4 投票
2 回答
1342 浏览
提问于 2025-04-18 16:46

如果我想从某个函数 f 给定的列表或集合 x 中找到一个最小值,我可以使用一些方便的单行代码,比如:

min(x,key=f)

(4.91 微秒)

不过,对于一个“纯粹”的最小值函数来说,通常没有必要返回多个元素(因为它们都是一样的,而集合中只有一个元素)。但是,如果你是根据某个函数来选择最小值的,通常你会想知道所有满足最小条件的元素。

换句话说,我想要一个简短、清晰且快速的函数,能够根据某种加权函数返回所有的最小元素,最好能同时适用于列表和集合(并且返回的结果与输入的数据类型一致)。

对于列表,我写的最快的代码是:

def allmin(x,f):
    vals = map(f, x)
    minval = min(vals)
    return [x[i] for i,e in enumerate(vals) if e==minval]

6.73 微秒

不过,这距离最佳效果还很远,并且不适用于集合。首先,在映射时,所有的函数值在某个时刻都会在内存中,因此这是确定最小值的最佳时机,而不是再去查看一次。这一点可以通过这个例子看出,尽管与单一最小值的例子相比,除了构建列表外不应该有额外的计算,但这个方法已经慢了50%。我能写出的唯一一个适用于集合的类似代码是:

def allmin(x,f):
    vals = [(f(e), e) for e in x]
    minval = min(vals)[0]
    return {e for val,e in vals if val==minval}

8.44 微秒(列表版本使用列表推导时为7.29微秒)

有没有办法让我在列表上的性能接近于更好的 allmin 版本的性能,最好能接近 min(x,key=f) 的性能呢?

(为了说明和计时,我假设了:

f = lambda x: (x-4.5)**2
x = random.choice([[0,1,2,3,4,5,6,7,8,9,10,11,13],{0,1,2,3,4,5,6,7,8,9,10,11,13}])

2 个回答

1

你现在的做法是对所有元素应用 f() 函数,这需要花费 Θ(n) 的时间,然后再花 Θ(n) 的时间找出这些元素中的最小值,最后还要花 Θ(n) 的时间找出所有等于这个最小值的元素。简单来说,你总共花了 3 x Θ(n) 的时间,其中 n 是输入列表的大小。

理论上,你可以通过在应用 f() 的同时找出最小值来将这个过程缩短到 2 x Θ(n),然后再花 Θ(n) 的时间来获取所有最小值的元素。不过,似乎还有一种更快的方法,你可以在应用 f() 和找最小值时花费 Θ(n) 的时间,而在获取所有最小值时只需花费 O(n) 的时间。(注意,在最坏的情况下,O(n) 和 Θ(n) 是没有区别的。在下面的算法中,最坏情况发生在列表中的所有元素都相同,或者列表是按反向顺序排列的。)

def allmin(x,f):
    minVal = 9999999999999999999999999
    mapped = []
    for a in x:
        mapVal = f(a)
        if mapVal <= minVal:
            minVal = mapVal
            mapped.append((a, mapVal))
    return [a for (a,m) in mapped if m == minVal]

我自己的时间测量显示,对于一个包含从 0 到 100 的整数的列表,相比于你的 allmin() 方法,我的方法大约提高了 20% 的速度。

对于非常大的输入列表,开始时可以先抽样几个元素,这样可以为 minVal 提供一个更好的初始值(而不是简单地初始化为一个非常大的值)。

========================================= 编辑 =========================================

这里有一个版本,可以进一步提高 5~10% 的速度。这种加速的原因是,一旦找到新的最小值,之前存储的所有映射值就可以丢弃。因此,最终获取最小值的 O(n) 时间就不再需要,整个算法只需 1 x Θ(n) 的时间来运行。

def newallmin(x,f):
    minVal = f(x[-1])
    minList = []
    for a in x:
        mapVal = f(a)
        if mapVal > minVal:
            continue
        if mapVal < minVal:
            minVal = mapVal
            minList = [a]
        else: # mapVal == minVal
            minList.append(a)
    return minList

我一直在对一个大小为 10,000,000 的列表进行时间测量,所有元素的范围是从 0 到 100。

2

如果你不知道最小值的数量,那么一个简单的方法就是在遍历过程中,保持一个记录当前最小值的列表,记录到目前为止看到的最小权重:

def minimal(iterable, func):
    'Return a list of minimal values according to a weighting function'
    it = iter(iterable)
    try:
        x = next(it)
    except StopIteration:
        return []
    lowest_values = [x]
    lowest_weight = func(x)
    for x in it:
        weight = func(x)
        if weight == lowest_weight:
            lowest_values.append(x)
        elif weight < lowest_weight:
            lowest_values = []
            lowest_weight = weight
    return lowest_values

下面是这个方法的实际应用:

>>> s = {'abc', 'defg', 'hij', 'kl', 'mno', 'qr', 'stuv', 'wx', 'yz'}
>>> minimal(s, len)
['qr', 'kl', 'yz', 'wx']

另外,如果你提前知道有多少个最小值,可以使用heapq.nsmallest这个函数,它可以直接高效地解决这个问题。对于从个值中找出k个最小值,它会调用你的权重函数次,并且使用的内存与k成正比(也就是说,它非常节省缓存):

>>> from heapq import nsmallest
>>> s = {'abc', 'defg', 'hij', 'kl', 'mno', 'qr', 'stuv', 'wx', 'yz'}
>>> nsmallest(4, s, key=len)
['qr', 'kl', 'yz', 'wx']

撰写回答