在numpy和python中快速去除重复项
有没有什么快速的方法可以在numpy中获取唯一的元素?我有一段类似这样的代码(最后一行)
tab = numpy.arange(100000000)
indices1 = numpy.random.permutation(10000)
indices2 = indices1.copy()
indices3 = indices1.copy()
indices4 = indices1.copy()
result = numpy.unique(numpy.array([tab[indices1], tab[indices2], tab[indices3], tab[indices4]]))
这只是一个例子,在我的情况中,indices1, indices2,...,indices4
包含不同的索引集合,并且大小各异。最后一行代码执行了很多次,我发现这实际上是我代码中的瓶颈(具体来说是{numpy.core.multiarray.arange}
)。另外,顺序并不重要,索引数组中的元素是int32
类型。我在考虑使用哈希表,把元素值作为键,并尝试了:
seq = itertools.chain(tab[indices1].flatten(), tab[indices2].flatten(), tab[indices3].flatten(), tab[indices4].flatten())
myset = {}
map(myset.__setitem__, seq, [])
result = numpy.array(myset.keys())
但效果更糟。
有没有什么方法可以加快这个速度?我猜性能下降是因为“花哨索引”会复制数组,但我只需要读取结果元素(我并不修改任何东西)。
2 个回答
以下内容其实有部分不准确(见附注):
获取所有子数组中唯一元素的这个方法非常快:
seq = itertools.chain(tab[indices1].flat, tab[indices2].flat, tab[indices3].flat, tab[indices4].flat)
result = set(seq)
注意,这里使用的是 flat
(它返回一个迭代器),而不是 flatten()
(它返回一个完整的数组),而且可以直接调用 set()
(而不是像你第二种方法那样使用 map()
和字典)。
以下是一些时间测试结果(在 IPython 终端中获得):
>>> %timeit result = numpy.unique(numpy.array([tab[indices1], tab[indices2], tab[indices3], tab[indices4]]))
100 loops, best of 3: 8.04 ms per loop
>>> seq = itertools.chain(tab[indices1].flat, tab[indices2].flat, tab[indices3].flat, tab[indices4].flat)
>>> %timeit set(seq)
1000000 loops, best of 3: 223 ns per loop
因此,在这个例子中,set/flat 方法快了 40 倍。
附注:set(seq)
的时间其实并不具有代表性。实际上,第一次循环的时间测试清空了 seq
迭代器,后续的 set()
评估返回的是一个空集合!正确的时间测试如下:
>>> %timeit set(itertools.chain(tab[indices1].flat, tab[indices2].flat, tab[indices3].flat, tab[indices4].flat))
100 loops, best of 3: 9.12 ms per loop
这显示出 set/flat 方法实际上并没有更快。
附加附注:这是对 mtrw 建议的一个(不成功的)探索;提前找到唯一的索引可能是个好主意,但我找不到比上述方法更快的实现方式:
>>> %timeit set(indices1).union(indices2).union(indices3).union(indices4)
100 loops, best of 3: 11.9 ms per loop
>>> %timeit set(itertools.chain(indices1.flat, indices2.flat, indices3.flat, indices4.flat))
100 loops, best of 3: 10.8 ms per loop
因此,找到所有不同索引的集合本身是相当慢的。
再附注:numpy.unique(<连接后的索引数组>)
实际上比 set(<连接后的索引数组>)
快 2-3 倍。这是 Bago 答案中提到的加速的关键(unique(concatenate((…)))
)。原因可能是让 NumPy 自己处理数组通常比用纯 Python 的 set
来处理 NumPy 数组要快。
总结:因此,这个回答只记录了一些失败的尝试,不应该完全遵循,同时也提到了一些关于用迭代器进行时间测试的有用备注……
抱歉,我不太明白你的问题,但我会尽力帮助你。
首先,{numpy.core.multiarray.arange} 是 numpy.arange,而不是所谓的花式索引。不幸的是,花式索引在性能分析工具中不会单独显示。如果你在循环中调用 np.arange,建议你看看能不能把它移到循环外面。
In [27]: prun tab[tab]
2 function calls in 1.551 CPU seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 1.551 1.551 1.551 1.551 <string>:1(<module>)
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
In [28]: prun numpy.arange(10000000)
3 function calls in 0.051 CPU seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.047 0.047 0.047 0.047 {numpy.core.multiarray.arange}
1 0.003 0.003 0.051 0.051 <string>:1(<module>)
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
其次,我假设你的代码中 tab
不是 np.arange(a, b)
,因为如果是的话,tab[index] == index + a
,但我猜这只是你举例用的。
第三,np.concatenate 的速度大约是 np.array 的十倍。
In [47]: timeit numpy.array([tab[indices1], tab[indices2], tab[indices3], tab[indices4]])
100 loops, best of 3: 5.11 ms per loop
In [48]: timeit numpy.concatenate([tab[indices1], tab[indices2], tab[indices3], tab[indices4]])
1000 loops, best of 3: 544 us per loop
(另外,np.concatenate 会返回一个形状为 (4*n,) 的数组,而 np.array 返回的是形状为 (4, n) 的数组,其中 n 是 indices[1-4] 的长度。后者只有在 indices1-4 的长度都相同的情况下才能工作。)
最后,如果你能这样做,还可以节省更多时间:
indices = np.unique(np.concatenate((indices1, indices2, indices3, indices4)))
result = tab[indices]
按照这个顺序做会更快,因为你减少了需要在 tab 中查找的索引数量,但前提是你知道 tab 中的元素是唯一的(否则即使索引是唯一的,结果中也可能会出现重复)。
希望这些对你有帮助。