将Numpy数组作为查找表

4 投票
2 回答
9559 浏览
提问于 2025-04-17 23:47

相关但不同,我认为:

(1) numpy:数组中唯一值的最高效频率计数

(2) 将Numpy数组用作查找表

设置:

import numpy as np
from scipy.stats import itemfreq

x = np.array([1,  1,  1,  2,  25000, 2,  2,  5,  1,  1])
fq = itemfreq(x)
fq.astype(int)
array([[    1,     5],
       [    2,     3],
       [    5,     1],
       [25000,     1]])

现在,我想把fq当作查找表来用,做这个:

res = magic_lookup_function(fq, x)
res
    array([5, 5, 5, 3, 1, 3, 3, 1, 5, 5])

正如(1)和(2)中提到的,我可以把fq转换成一个python字典,然后从那里查找,再转换回np.array。但是有没有更简单、更快、纯粹用numpy的方法呢?

更新:另外,正如(2)中提到的,我可以使用bincount,但我担心如果我的索引很大,比如大约250,000,这样做可能效率不高。

谢谢!

更新的解决方案

正如@Jaime在下面指出的,np.unique对数组进行排序,最好的情况下是O(n log n)的时间复杂度。所以我想知道,itemfreq在内部是怎么处理的?结果发现itemfreq也会对数组进行排序,我假设这也是O(n log n):

In [875]: itemfreq??

def itemfreq(a):
... ... ...
    scores = _support.unique(a)
    scores = np.sort(scores)

这是一个timeit的例子

In [895]: import timeit

In [962]: timeit.timeit('fq = itemfreq(x)', setup='import numpy; from scipy.stats import itemfreq; x = numpy.array([ 1,  1,  1,  2, 250000,  2,  2,  5,  1,  1])', number=1000)
Out[962]: 0.3219749927520752

但似乎没有必要对数组进行排序。如果我们用纯python来做,会发生什么呢。

In [963]: def test(arr):
   .....:     fd = {}
   .....:     for i in arr:
   .....:         fd[i] = fd.get(i,0) + 1
   .....:     return numpy.array([fd[j] for j in arr])

In [967]: timeit.timeit('test(x)', setup='import numpy; from __main__ import test; x = numpy.array([ 1,  1,  1,  2, 250000,  2,  2,  5,  1,  1])', number=1000)
Out[967]: 0.028257131576538086

哇,快了10倍!

(至少在这种情况下,数组不太长,但可能包含大值。)

而且,正如我所怀疑的,使用np.bincount处理大值时效率不高:

In [970]: def test2(arr):
    bc = np.bincount(arr)
    return bc[arr]

In [971]: timeit.timeit('test2(x)', setup='import numpy; from __main__ import test2; x = numpy.array([ 1,  1,  1,  2, 250000,  2,  2,  5,  1,  1])', number=1000)
Out[971]: 0.0975029468536377

2 个回答

0

因为你的查找表不仅仅是普通的查找表,而是一个频率列表,所以你可能想考虑以下选项:

>>> x = np.array([1,  1,  1,  2,  25, 2,  2,  5,  1,  1])
>>> x_unq, x_idx = np.unique(x, return_inverse=True)
>>> np.take(np.bincount(x_idx), x_idx)
array([5, 5, 5, 3, 1, 3, 3, 1, 5, 5], dtype=int64)

即使你的查找表比较复杂,也就是说:

>>> lut = np.array([[ 1, 10],
...                 [ 2,  9],
...                 [ 5,  8],
...                 [25,  7]])

如果你可以使用 np.unique(这个函数会对数组进行排序,所以时间复杂度是 n log n)并且设置 return_index,那么你可以用小的连续整数作为索引,这样通常会让事情变得更简单。例如,使用 np.searchsorted,你可以这样做:

>>> np.take(lut[:, 1], np.take(np.searchsorted(lut[:, 0], x_unq), x_idx))
array([10, 10, 10,  9,  7,  9,  9,  8, 10, 10])
4

你可以使用 numpy.searchsorted 这个功能:

def get_index(arr, val):                                                                
    index = np.searchsorted(arr, val)                                                            
    if arr[index] == val:                                                                        
        return index                                                                             

In [20]: arr = fq[:,:1].ravel()                                                                  

In [21]: arr
Out[21]: array([  1.,   2.,   5.,  25.])

In [22]: get_index(arr, 25)                                                                      
Out[22]: 3

In [23]: get_index(arr, 2)                                                                       
Out[23]: 1

In [24]: get_index(arr, 4)    #returns `None` for  item not found.  

撰写回答