将Numpy数组作为查找表
相关但不同,我认为:
(2) 将Numpy数组用作查找表
设置:
import numpy as np
from scipy.stats import itemfreq
x = np.array([1, 1, 1, 2, 25000, 2, 2, 5, 1, 1])
fq = itemfreq(x)
fq.astype(int)
array([[ 1, 5],
[ 2, 3],
[ 5, 1],
[25000, 1]])
现在,我想把fq当作查找表来用,做这个:
res = magic_lookup_function(fq, x)
res
array([5, 5, 5, 3, 1, 3, 3, 1, 5, 5])
正如(1)和(2)中提到的,我可以把fq转换成一个python字典,然后从那里查找,再转换回np.array。但是有没有更简单、更快、纯粹用numpy的方法呢?
更新:另外,正如(2)中提到的,我可以使用bincount,但我担心如果我的索引很大,比如大约250,000,这样做可能效率不高。
谢谢!
更新的解决方案
正如@Jaime在下面指出的,np.unique对数组进行排序,最好的情况下是O(n log n)的时间复杂度。所以我想知道,itemfreq
在内部是怎么处理的?结果发现itemfreq
也会对数组进行排序,我假设这也是O(n log n):
In [875]: itemfreq??
def itemfreq(a):
... ... ...
scores = _support.unique(a)
scores = np.sort(scores)
这是一个timeit的例子
In [895]: import timeit
In [962]: timeit.timeit('fq = itemfreq(x)', setup='import numpy; from scipy.stats import itemfreq; x = numpy.array([ 1, 1, 1, 2, 250000, 2, 2, 5, 1, 1])', number=1000)
Out[962]: 0.3219749927520752
但似乎没有必要对数组进行排序。如果我们用纯python来做,会发生什么呢。
In [963]: def test(arr):
.....: fd = {}
.....: for i in arr:
.....: fd[i] = fd.get(i,0) + 1
.....: return numpy.array([fd[j] for j in arr])
In [967]: timeit.timeit('test(x)', setup='import numpy; from __main__ import test; x = numpy.array([ 1, 1, 1, 2, 250000, 2, 2, 5, 1, 1])', number=1000)
Out[967]: 0.028257131576538086
哇,快了10倍!
(至少在这种情况下,数组不太长,但可能包含大值。)
而且,正如我所怀疑的,使用np.bincount
处理大值时效率不高:
In [970]: def test2(arr):
bc = np.bincount(arr)
return bc[arr]
In [971]: timeit.timeit('test2(x)', setup='import numpy; from __main__ import test2; x = numpy.array([ 1, 1, 1, 2, 250000, 2, 2, 5, 1, 1])', number=1000)
Out[971]: 0.0975029468536377
2 个回答
0
因为你的查找表不仅仅是普通的查找表,而是一个频率列表,所以你可能想考虑以下选项:
>>> x = np.array([1, 1, 1, 2, 25, 2, 2, 5, 1, 1])
>>> x_unq, x_idx = np.unique(x, return_inverse=True)
>>> np.take(np.bincount(x_idx), x_idx)
array([5, 5, 5, 3, 1, 3, 3, 1, 5, 5], dtype=int64)
即使你的查找表比较复杂,也就是说:
>>> lut = np.array([[ 1, 10],
... [ 2, 9],
... [ 5, 8],
... [25, 7]])
如果你可以使用 np.unique
(这个函数会对数组进行排序,所以时间复杂度是 n log n)并且设置 return_index
,那么你可以用小的连续整数作为索引,这样通常会让事情变得更简单。例如,使用 np.searchsorted
,你可以这样做:
>>> np.take(lut[:, 1], np.take(np.searchsorted(lut[:, 0], x_unq), x_idx))
array([10, 10, 10, 9, 7, 9, 9, 8, 10, 10])
4
你可以使用 numpy.searchsorted
这个功能:
def get_index(arr, val):
index = np.searchsorted(arr, val)
if arr[index] == val:
return index
In [20]: arr = fq[:,:1].ravel()
In [21]: arr
Out[21]: array([ 1., 2., 5., 25.])
In [22]: get_index(arr, 25)
Out[22]: 3
In [23]: get_index(arr, 2)
Out[23]: 1
In [24]: get_index(arr, 4) #returns `None` for item not found.