在Python中快速获取加权数组的随机索引方法
我经常需要从一个数组或列表中随机选择一个索引,但这些索引的选择概率不是均匀的,而是根据某些正权重来分配的。有没有什么快速的方法可以做到这一点?我知道可以把权重作为可选参数 p
传给 numpy.random.choice
,但是这个函数似乎比较慢,而且构建一个 arange
来传递权重也不是个好主意。权重的总和可以是任意正数,并不一定是1,这让生成一个在 (0,1] 之间的随机数,然后减去权重直到结果为0或更小的方法变得不可行。
虽然网上有一些关于如何简单实现类似功能的答案(大多数是关于获取对应元素,而不是数组索引),比如 加权选择简单短小,但我在寻找一个快速的解决方案,因为这个合适的函数会被频繁调用。我的权重变化很快,所以像别名掩码这样的构建开销(详细介绍可以在 http://www.keithschwarz.com/darts-dice-coins/ 找到)也应该算在计算时间里。
2 个回答
从Python 3.6开始,标准库里的random.choices
函数可以接受weights
或者cum_weights
这两个参数。
累积求和和二分查找
在一般情况下,建议先计算权重的累积和,然后使用二分查找模块中的 bisect 来在结果的排序数组中找到一个随机点。
def weighted_choice(weights):
cs = numpy.cumsum(weights)
return bisect.bisect(cs, numpy.random.random() * cs[-1])
如果速度是一个考虑因素,下面会有更详细的分析。
注意:如果数组不是一维的,可以使用 numpy.unravel_index
将一维索引转换为多维索引,具体可以参考 这个链接。
实验分析
使用 numpy
内置函数有四种或多种明显的解决方案。通过 timeit
比较它们的性能,得到了以下结果:
import timeit
weighted_choice_functions = [
"""import numpy
wc = lambda weights: numpy.random.choice(
range(len(weights)),
p=weights/weights.sum())
""",
"""import numpy
# Adapted from https://stackoverflow.com/a/19760118/1274613
def wc(weights):
cs = numpy.cumsum(weights)
return cs.searchsorted(numpy.random.random() * cs[-1], 'right')
""",
"""import numpy, bisect
# Using bisect mentioned in https://stackoverflow.com/a/13052108/1274613
def wc(weights):
cs = numpy.cumsum(weights)
return bisect.bisect(cs, numpy.random.random() * cs[-1])
""",
"""import numpy
wc = lambda weights: numpy.random.multinomial(
1,
weights/weights.sum()).argmax()
"""]
for setup in weighted_choice_functions:
for ps in ["numpy.ones(40)",
"numpy.arange(10)",
"numpy.arange(200)",
"numpy.arange(199,-1,-1)",
"numpy.arange(4000)"]:
timeit.timeit("wc(%s)"%ps, setup=setup)
print()
结果输出为
178.45797914802097
161.72161589498864
223.53492237901082
224.80936180002755
1901.6298267539823
15.197789980040397
19.985687876993325
20.795070077001583
20.919113760988694
41.6509403079981
14.240949985047337
17.335801470966544
19.433710905024782
19.52205040602712
35.60536142199999
26.6195822560112
20.501282756973524
31.271995796996634
27.20013752405066
243.09768892999273
这意味着 numpy.random.choice
的速度出乎意料地慢,甚至专门为此设计的 searchsorted
方法也比简单的 bisect
变体要慢。(这些结果是使用 Python 3.3.5 和 numpy 1.8.1 得到的,其他版本可能会有所不同。)基于 numpy.random.multinomial
的函数在处理大权重时效率不如基于累积求和的方法。可以推测,argmax 需要遍历整个数组并在每一步进行比较,这对性能有显著影响,这也可以从权重列表递增和递减之间四秒的差异中看出。