使用itertools.groupby的NumPy分组性能

27 投票
10 回答
24943 浏览
提问于 2025-04-16 09:44

我有很多很大的整数列表(超过3500万),这些列表里会有重复的数字。我需要统计每个整数在列表中出现的次数。下面的代码可以实现这个功能,但感觉速度有点慢。有没有人能用Python,最好是用NumPy,来优化一下这个过程?

def group():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0,1<<32, size=35000000), dtype='u4')
    values.sort()
    groups = ((k, len(list(g))) for k,g in groupby(values))
    index = np.fromiter(groups, dtype='u4,u2')

if __name__=='__main__':
    from timeit import Timer
    t = Timer("group()","from __main__ import group")
    print t.timeit(number=1)

这段代码返回:

$ python bench.py
111.377498865

根据大家的回复:

def group_original():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
    values.sort()
    groups = ((k, len(list(g))) for k,g in groupby(values))
    index = np.fromiter(groups, dtype='u4,u2')

def group_gnibbler():
    import numpy as np
    from itertools import groupby
    values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
    values.sort()
    groups = ((k,sum(1 for i in g)) for k,g in groupby(values))
    index = np.fromiter(groups, dtype='u4,u2')

def group_christophe():
    import numpy as np
    values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
    values.sort()
    counts=values.searchsorted(values, side='right') - values.searchsorted(values, side='left')
    index = np.zeros(len(values), dtype='u4,u2')
    index['f0'] = values
    index['f1'] = counts
    # Erroneous result!

def group_paul():
    import numpy as np
    values = np.array(np.random.randint(0, 1<<32, size=35000000), dtype='u4')
    values.sort()
    diff = np.concatenate(([1], np.diff(values)))
    idx = np.concatenate((np.where(diff)[0], [len(values)]))
    index = np.empty(len(idx)-1, dtype='u4,u2')
    index['f0'] = values[idx[:-1]]
    index['f1'] = np.diff(idx)

if __name__=='__main__':
    from timeit import Timer
    timings=[
                ("group_original", "Original"),
                ("group_gnibbler", "Gnibbler"),
                ("group_christophe", "Christophe"),
                ("group_paul", "Paul"),
            ]
    for method,title in timings:
        t = Timer("%s()"%method,"from __main__ import %s"%method)
        print "%s: %s secs"%(title, t.timeit(number=1))

这段代码返回:

$ python bench.py
Original: 113.385262966 secs
Gnibbler: 71.7464978695 secs
Christophe: 27.1690568924 secs
Paul: 9.06268405914 secs

不过,Christophe目前给出的结果是不正确的。

10 个回答

5

根据请求,这里有一个Cython版本的代码。我对数组进行了两次处理。第一次是找出有多少个独特的元素,这样我就可以为这些独特的值和它们的计数准备合适大小的数组。

import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
def dogroup():
    cdef unsigned long tot = 1
    cdef np.ndarray[np.uint32_t, ndim=1] values = np.array(np.random.randint(35000000,size=35000000),dtype=np.uint32)
    cdef unsigned long i, ind, lastval
    values.sort()
    for i in xrange(1,len(values)):
        if values[i] != values[i-1]:
            tot += 1
    cdef np.ndarray[np.uint32_t, ndim=1] vals = np.empty(tot,dtype=np.uint32)
    cdef np.ndarray[np.uint32_t, ndim=1] count = np.empty(tot,dtype=np.uint32)
    vals[0] = values[0]
    ind = 1
    lastval = 0
    for i in xrange(1,len(values)):
        if values[i] != values[i-1]:
            vals[ind] = values[i]
            count[ind-1] = i - lastval
            lastval = i
            ind += 1
    count[ind-1] = len(values) - lastval

在这个过程中,排序实际上是耗时最多的。根据我代码中给出的值数组,排序花了4.75秒,而真正找出独特值和计数只花了0.67秒。使用保罗的代码的纯Numpy版本(但值数组的形式是一样的),在我评论中建议的修复后,找出独特值和计数花了1.9秒(当然,排序的时间还是一样的)。

排序占用大部分时间是有道理的,因为排序的复杂度是O(N log N),而计数的复杂度是O(N)。你可以稍微加快排序的速度,超过Numpy的速度(如果我没记错的话,Numpy使用的是C语言的qsort),但你必须非常了解这个过程,可能也不值得。此外,我的Cython代码可能还有一些方法可以稍微加快速度,但可能也不值得去追求。

14

自从保罗的回答被接受以来,已经过去了超过5年。有趣的是,sort() 仍然是这个解决方案中的瓶颈。

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           @profile
     4                                           def group_paul():
     5         1        99040  99040.0      2.4      import numpy as np
     6         1       305651 305651.0      7.4      values = np.array(np.random.randint(0, 2**32,size=35000000),dtype='u4')
     7         1      2928204 2928204.0    71.3      values.sort()
     8         1        78268  78268.0      1.9      diff = np.concatenate(([1],np.diff(values)))
     9         1       215774 215774.0      5.3      idx = np.concatenate((np.where(diff)[0],[len(values)]))
    10         1           95     95.0      0.0      index = np.empty(len(idx)-1,dtype='u4,u2')
    11         1       386673 386673.0      9.4      index['f0'] = values[idx[:-1]]
    12         1        91492  91492.0      2.2      index['f1'] = np.diff(idx)

在我的机器上,接受的解决方案运行需要4.0秒,而使用基数排序后,这个时间降到了1.7秒。

仅仅通过切换到基数排序,我的整体速度提升了2.35倍。 在这种情况下,基数排序比快速排序快了超过4倍。

可以查看这个链接 如何比快速排序更快地对整数数组进行排序?,这个问题是受到你提问的启发。


在性能分析时,我使用了 line_profiler 和 kernprof@profile 就是来自这里)。

32

我这样做后,效果提升了三倍:

def group():
    import numpy as np
    values = np.array(np.random.randint(0, 3298, size=35000000), dtype='u4')
    values.sort()
    dif = np.ones(values.shape, values.dtype)
    dif[1:] = np.diff(values)
    idx = np.where(dif>0)
    vals = values[idx]
    count = np.diff(idx)

撰写回答