Python中快速检查范围的方法
我有很多范围,比如 [(1, 1000), (5000, 5678), ... ]
这样的格式。我想找出最快的方法来检查一个数字是否在这些范围内。这些范围的数字很大,所以不能简单地把所有数字放在一个 set
里。
最简单的解决方案是:
ranges = [(1,5), (10,20), (40,50)] # The real code has a few dozen ranges
nums = range(1000000)
%timeit [n for n in nums if any([r[0] <= n <= r[1] for r in ranges])]
# 1 loops, best of 3: 5.31 s per loop
Banyan 的速度稍微快一些:
import banyan
banyan_ranges = banyan.SortedSet(updator=banyan.OverlappingIntervalsUpdator)
for r in ranges:
banyan_ranges.add(r)
%timeit [n for n in nums if len(banyan_ranges.overlap_point(n))>0]
# 1 loops, best of 3: 452 ms per loop
虽然范围只有几十个,但要检查的次数却有上百万次。那么,做这些检查的最快方法是什么呢?
(注意:这个问题和 Python: 高效检查整数是否在多个范围内 有点相似,但没有 Django 相关的限制,主要关注速度问题)
3 个回答
试着用二分查找来代替线性查找。这样做的时间复杂度是“Log(n)”。具体可以看下面的代码:
list = []
for num in nums:
start = 0
end = len(ranges)-1
if ranges[start][0] <= num <= ranges[start][1]:
list.append(num)
elif ranges[end][0] <= num <= ranges[end][1]:
list.append(num):
else:
while end-start>1:
mid = int(end+start/2)
if ranges[mid][0] <= num <= ranges[mid][1]:
list.append(num)
break
elif num < ranges[mid][0]:
end = mid
else:
start = mid
这是对@ArminRigo评论的一种实现,速度相当快。这个时间测量是基于CPython,而不是PyPy:
exec_code = "def in_range(x):\n"
first_if = True
for r in ranges:
if first_if:
exec_code += " if "
first_if = False
else:
exec_code += " elif "
exec_code += "%d <= x <= %d: return True\n" % (r[0], r[1])
exec_code += " return False"
exec(exec_code)
%timeit [n for n in nums if in_range(n)]
# 10 loops, best of 3: 173 ms per loop
可以尝试的办法:
- 先处理你的范围,使它们不重叠,并把它们表示为半开区间。
- 使用
bisect
模块来进行查找。(不要自己手动实现二分查找!)注意,经过第一步的处理后,你只需要知道bisect
调用的结果是偶数还是奇数。 - 如果可以批量处理查询,可以考虑把输入分组到一个数组中,然后使用
numpy.searchsorted
。
下面是一些代码和时间测试。首先是设置(这里使用 IPython 2.1 和 Python 3.4):
In [1]: ranges = [(1, 5), (10, 20), (40, 50)]
In [2]: nums = list(range(1000000)) # force a list to remove generator overhead
在我的机器上,原始方法的时间测试(但使用生成器表达式而不是列表推导式):
In [3]: %timeit [n for n in nums if any(r[0] <= n <= r[1] for r in ranges)]
1 loops, best of 3: 922 ms per loop
现在我们将范围重新处理为一个边界点的列表;每个在 偶数 索引的边界点是某个范围的入口点,而每个在 奇数 索引的边界点是出口点。注意转换为半开区间,并且我把所有数字放在一个列表里。
In [4]: boundaries = [1, 6, 10, 21, 40, 51]
这样就可以轻松使用 bisect.bisect
得到和之前一样的结果,但速度更快。
In [5]: from bisect import bisect
In [6]: %timeit [n for n in nums if bisect(boundaries, n) % 2]
1 loops, best of 3: 298 ms per loop
最后,根据具体情况,你可能可以利用 NumPy 的 searchsorted
函数。这个函数和 bisect.bisect
类似,但可以一次处理一整组值。例如:
In [7]: import numpy
In [8]: numpy.where(numpy.searchsorted(boundaries, nums, side="right") % 2)[0]
Out[8]:
array([ 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 40,
41, 42, 43, 44, 45, 46, 47, 48, 49, 50])
乍一看,这个 %timeit
的结果有点让人失望。
In [9]: %timeit numpy.where(numpy.searchsorted(boundaries, nums, side="right") % 2)[0]
10 loops, best of 3: 159 ms per loop
不过,实际上性能损失主要是在把输入从 Python 列表转换为 NumPy 数组的过程中。让我们先把两个列表转换成数组,然后再试一次:
In [10]: boundaries = numpy.array(boundaries)
In [11]: nums = numpy.array(nums)
In [12]: %timeit numpy.where(numpy.searchsorted(boundaries, nums, side="right") % 2)[0]
10 loops, best of 3: 24.6 ms per loop
比之前的任何方法都要快得多。不过,这有点作弊:我们当然可以预处理 boundaries
使其变成数组,但如果你想测试的值不是自然以数组形式产生的,那么转换的成本就需要考虑了。另一方面,这表明搜索本身的成本可以降低到一个足够小的值,这样就不再是运行时间的主要因素了。
这里还有另一个选择。它再次使用 NumPy,但对每个值进行直接的非懒惰线性搜索。(请原谅 IPython
提示的顺序不对:我稍后添加了这个。:-)
In [29]: numpy.where(numpy.logical_xor.reduce(numpy.greater_equal.outer(boundaries, nums), axis=0))
Out[29]:
(array([ 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51]),)
In [30]: %timeit numpy.where(numpy.logical_xor.reduce(numpy.greater_equal.outer(boundaries, nums), axis=0))
10 loops, best of 3: 16.7 ms per loop
对于这些特定的测试数据,这比 searchsorted
快,但时间会随着范围数量线性增长,而 searchsorted
的增长应该是范围数量的对数。注意,它还使用了与 len(boundaries) * len(nums)
成正比的内存。这不一定是个问题:如果你发现内存受限,可以考虑把数组分成更小的块(比如每次处理 10000 个元素),这样性能损失不会太大。
如果这些方法都不合适,我接下来会尝试 Cython 和 NumPy,编写一个搜索函数(输入声明为整数数组),对 boundaries
数组进行简单的线性搜索。我尝试过这个,但没有得到比 bisect.bisect
更好的结果。作为参考,这里是我尝试的 Cython 代码;你可能能做得更好:
cimport cython
cimport numpy as np
@cython.boundscheck(False)
@cython.wraparound(False)
def search(np.ndarray[long, ndim=1] boundaries, long val):
cdef long j, k, n=len(boundaries)
for j in range(n):
if boundaries[j] > val:
return j & 1
return 0
还有时间测试:
In [13]: from my_cython_extension import search
In [14]: %timeit [n for n in nums if search(boundaries, n)]
1 loops, best of 3: 793 ms per loop