使用itertools.groupby的NumPy分组性能
我有很多很大的整数列表(超过3500万),这些列表里会有重复的数字。我需要统计每个整数在列表中出现的次数。下面的代码可以实现这个功能,但感觉速度有点慢。有没有人能用Python,最好是用NumPy,来优化一下这个过程?
def group():
import numpy as np
from itertools import groupby
values = np.array(np.random.randint(0,1<<32, size=35000000), dtype='u4')
values.sort()
groups = ((k, len(list(g))) for k,g in groupby(values))
index = np.fromiter(groups, dtype='u4,u2')
if __name__=='__main__':
from timeit import Timer
t = Timer("group()","from __main__ import group")
print t.timeit(number=1)
这段代码返回:
$ python bench.py
111.377498865
根据大家的回复:
def group_original():
import numpy as np
from itertools import groupby
values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
values.sort()
groups = ((k, len(list(g))) for k,g in groupby(values))
index = np.fromiter(groups, dtype='u4,u2')
def group_gnibbler():
import numpy as np
from itertools import groupby
values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
values.sort()
groups = ((k,sum(1 for i in g)) for k,g in groupby(values))
index = np.fromiter(groups, dtype='u4,u2')
def group_christophe():
import numpy as np
values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
values.sort()
counts=values.searchsorted(values, side='right') - values.searchsorted(values, side='left')
index = np.zeros(len(values), dtype='u4,u2')
index['f0'] = values
index['f1'] = counts
# Erroneous result!
def group_paul():
import numpy as np
values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
values.sort()
diff = np.concatenate(([1], np.diff(values)))
idx = np.concatenate((np.where(diff)[0], [len(values)]))
index = np.empty(len(idx)-1, dtype='u4,u2')
index['f0'] = values[idx[:-1]]
index['f1'] = np.diff(idx)
if __name__=='__main__':
from timeit import Timer
timings=[
("group_original", "Original"),
("group_gnibbler", "Gnibbler"),
("group_christophe", "Christophe"),
("group_paul", "Paul"),
]
for method,title in timings:
t = Timer("%s()"%method,"from __main__ import %s"%method)
print "%s: %s secs"%(title, t.timeit(number=1))
这段代码返回:
$ python bench.py
Original: 113.385262966 secs
Gnibbler: 71.7464978695 secs
Christophe: 27.1690568924 secs
Paul: 9.06268405914 secs
不过,Christophe目前给出的结果是不正确的。
10 个回答
根据请求,这里有一个Cython版本的代码。我对数组进行了两次处理。第一次是找出有多少个独特的元素,这样我就可以为这些独特的值和它们的计数准备合适大小的数组。
import numpy as np
cimport numpy as np
cimport cython
@cython.boundscheck(False)
def dogroup():
cdef unsigned long tot = 1
cdef np.ndarray[np.uint32_t, ndim=1] values = np.array(np.random.randint(35000000,size=35000000),dtype=np.uint32)
cdef unsigned long i, ind, lastval
values.sort()
for i in xrange(1,len(values)):
if values[i] != values[i-1]:
tot += 1
cdef np.ndarray[np.uint32_t, ndim=1] vals = np.empty(tot,dtype=np.uint32)
cdef np.ndarray[np.uint32_t, ndim=1] count = np.empty(tot,dtype=np.uint32)
vals[0] = values[0]
ind = 1
lastval = 0
for i in xrange(1,len(values)):
if values[i] != values[i-1]:
vals[ind] = values[i]
count[ind-1] = i - lastval
lastval = i
ind += 1
count[ind-1] = len(values) - lastval
在这个过程中,排序实际上是耗时最多的。根据我代码中给出的值数组,排序花了4.75秒,而真正找出独特值和计数只花了0.67秒。使用保罗的代码的纯Numpy版本(但值数组的形式是一样的),在我评论中建议的修复后,找出独特值和计数花了1.9秒(当然,排序的时间还是一样的)。
排序占用大部分时间是有道理的,因为排序的复杂度是O(N log N),而计数的复杂度是O(N)。你可以稍微加快排序的速度,超过Numpy的速度(如果我没记错的话,Numpy使用的是C语言的qsort),但你必须非常了解这个过程,可能也不值得。此外,我的Cython代码可能还有一些方法可以稍微加快速度,但可能也不值得去追求。
自从保罗的回答被接受以来,已经过去了超过5年。有趣的是,sort()
仍然是这个解决方案中的瓶颈。
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 @profile
4 def group_paul():
5 1 99040 99040.0 2.4 import numpy as np
6 1 305651 305651.0 7.4 values = np.array(np.random.randint(0, 2**32,size=35000000),dtype='u4')
7 1 2928204 2928204.0 71.3 values.sort()
8 1 78268 78268.0 1.9 diff = np.concatenate(([1],np.diff(values)))
9 1 215774 215774.0 5.3 idx = np.concatenate((np.where(diff)[0],[len(values)]))
10 1 95 95.0 0.0 index = np.empty(len(idx)-1,dtype='u4,u2')
11 1 386673 386673.0 9.4 index['f0'] = values[idx[:-1]]
12 1 91492 91492.0 2.2 index['f1'] = np.diff(idx)
在我的机器上,接受的解决方案运行需要4.0秒,而使用基数排序后,这个时间降到了1.7秒。
仅仅通过切换到基数排序,我的整体速度提升了2.35倍。 在这种情况下,基数排序比快速排序快了超过4倍。
可以查看这个链接 如何比快速排序更快地对整数数组进行排序?,这个问题是受到你提问的启发。
在性能分析时,我使用了 line_profiler 和 kernprof(@profile
就是来自这里)。
我这样做后,效果提升了三倍:
def group():
import numpy as np
values = np.array(np.random.randint(0, 3298, size=35000000), dtype='u4')
values.sort()
dif = np.ones(values.shape, values.dtype)
dif[1:] = np.diff(values)
idx = np.where(dif>0)
vals = values[idx]
count = np.diff(idx)