在Python中快速获取加权数组的随机索引方法

3 投票
2 回答
1768 浏览
提问于 2025-04-18 09:16

我经常需要从一个数组或列表中随机选择一个索引,但这些索引的选择概率不是均匀的,而是根据某些正权重来分配的。有没有什么快速的方法可以做到这一点?我知道可以把权重作为可选参数 p 传给 numpy.random.choice,但是这个函数似乎比较慢,而且构建一个 arange 来传递权重也不是个好主意。权重的总和可以是任意正数,并不一定是1,这让生成一个在 (0,1] 之间的随机数,然后减去权重直到结果为0或更小的方法变得不可行。

虽然网上有一些关于如何简单实现类似功能的答案(大多数是关于获取对应元素,而不是数组索引),比如 加权选择简单短小,但我在寻找一个快速的解决方案,因为这个合适的函数会被频繁调用。我的权重变化很快,所以像别名掩码这样的构建开销(详细介绍可以在 http://www.keithschwarz.com/darts-dice-coins/ 找到)也应该算在计算时间里。

2 个回答

0

从Python 3.6开始,标准库里的random.choices函数可以接受weights或者cum_weights这两个参数。

5

累积求和和二分查找

在一般情况下,建议先计算权重的累积和,然后使用二分查找模块中的 bisect 来在结果的排序数组中找到一个随机点。

def weighted_choice(weights):
    cs = numpy.cumsum(weights)
    return bisect.bisect(cs, numpy.random.random() * cs[-1])

如果速度是一个考虑因素,下面会有更详细的分析。

注意:如果数组不是一维的,可以使用 numpy.unravel_index 将一维索引转换为多维索引,具体可以参考 这个链接

实验分析

使用 numpy 内置函数有四种或多种明显的解决方案。通过 timeit 比较它们的性能,得到了以下结果:

import timeit

weighted_choice_functions = [
"""import numpy
wc = lambda weights: numpy.random.choice(
    range(len(weights)),
    p=weights/weights.sum())
""",
"""import numpy
# Adapted from https://stackoverflow.com/a/19760118/1274613
def wc(weights):
    cs = numpy.cumsum(weights)
    return cs.searchsorted(numpy.random.random() * cs[-1], 'right')
""",
"""import numpy, bisect
# Using bisect mentioned in https://stackoverflow.com/a/13052108/1274613
def wc(weights):
    cs = numpy.cumsum(weights)
    return bisect.bisect(cs, numpy.random.random() * cs[-1])
""",
"""import numpy
wc = lambda weights: numpy.random.multinomial(
    1,
    weights/weights.sum()).argmax()
"""]

for setup in weighted_choice_functions:
    for ps in ["numpy.ones(40)",
               "numpy.arange(10)",
               "numpy.arange(200)",
               "numpy.arange(199,-1,-1)",
               "numpy.arange(4000)"]:
        timeit.timeit("wc(%s)"%ps, setup=setup)
    print()

结果输出为

178.45797914802097
161.72161589498864
223.53492237901082
224.80936180002755
1901.6298267539823

15.197789980040397
19.985687876993325
20.795070077001583
20.919113760988694
41.6509403079981

14.240949985047337
17.335801470966544
19.433710905024782
19.52205040602712
35.60536142199999

26.6195822560112
20.501282756973524
31.271995796996634
27.20013752405066
243.09768892999273

这意味着 numpy.random.choice 的速度出乎意料地慢,甚至专门为此设计的 searchsorted 方法也比简单的 bisect 变体要慢。(这些结果是使用 Python 3.3.5 和 numpy 1.8.1 得到的,其他版本可能会有所不同。)基于 numpy.random.multinomial 的函数在处理大权重时效率不如基于累积求和的方法。可以推测,argmax 需要遍历整个数组并在每一步进行比较,这对性能有显著影响,这也可以从权重列表递增和递减之间四秒的差异中看出。

撰写回答