对于大型阵列,有没有比np.isin更快的方法?

2024-06-17 12:45:54 发布

您现在位置:Python中文网/ 问答频道 /正文

对于大型数组(n>;1e8),是否有比np.isin更快的方法来检查是否存在相同的元素

我试过几种方法,比如pandasisin,cython,但所有这些都比np.isin花费更多的时间

示例:(测试一维数组的每个元素是否也存在于第二个数组中)

num = int(1e8)
a = np.random.rand(int(num))
b = np.random.rand(int(num))

ref=time.time()
ainb = np.isin(a,b)
print(a[ainb])
print(time.time()-ref,'sec')

>>> [0.23591019 0.46102523]
>>> 65.45570135116577 sec

Tags: 方法gtref元素timenprandom数组
2条回答

我不确定这是否适合您的需要,但是集合的交集(&)看起来更快。缺点是你只会得到元素,而不是它们的索引

import numpy as np
import time

num = int(1e7)
a = np.random.rand(int(num))
b = np.random.rand(int(num))

a_set = set(a)
b_set = set(b)

交叉口的计时设置:

ref=time.time()
print(a_set & b_set)
print(time.time()-ref,'sec')

>> set()
0.6486170291900635 sec

计时np.isin()

ref=time.time()
ainb = np.isin(a, b)
print( a[ainb] )
print(time.time()-ref,'sec')

>> []
4.06556510925293 sec

如果您需要一个下拉列表(针对您的用例),但可能需要更快地替换np.isin(),那么您可以使用Python set()进行检查,并加速Numba中的显式循环:

import numpy as np
import numba as nb


@nb.jit
def is_in_set_nb(a, b):
    shape = a.shape
    a = a.ravel()
    n = len(a)
    result = np.full(n, False)
    set_b = set(b)
    for i in range(n):
        if a[i] in set_b:
            result[i] = True
    return result.reshape(shape)

请注意,有一些(便宜的)额外代码可以使其适用于N-dim阵列,如果您只需要1D,则可以省略这些代码

通过添加进一步的并行化,甚至可以加快速度:

import numpy as np
import numba as nb


@nb.jit(parallel=True)
def is_in_set_pnb(a, b):
    shape = a.shape
    a = a.ravel()
    n = len(a)
    result = np.full(n, False)
    set_b = set(b)
    for i in nb.prange(n):
        if a[i] in set_b:
            result[i] = True
    return result.reshape(shape)

这比np.isin()set()交叉点和is_in_set()无加速的解决方案快得多:

def is_in_set(a, b):
    set_b = set(b)
    return np.array([x in set_b for x in a])

输入大小为1000万个元素:

n = 10 ** 7
k = n // 3
np.random.seed(0)
# note: I used `int`s because I wanted to be able to control the collisions
a = np.random.randint(0, k * n, n)
b = np.random.randint(0, k * n, n)


%timeit ainb = np.isin(a, b); a[ainb]
# 1 loop, best of 3: 3.94 s per loop
%timeit ainb = is_in_set_nb(a, b); a[ainb]
# 1 loop, best of 3: 814 ms per loop
%timeit ainb = is_in_set_pnb(a, b); a[ainb]
# 1 loop, best of 3: 740 ms per loop
%timeit ainb = is_in_set(a, b); a[ainb]
# 1 loop, best of 3: 7.69 s per loop
%timeit set(a).intersection(b)  # not a drop-in replacement
# 1 loop, best of 3: 6.79 s per loop
%timeit set(a) & set(b)  # not a drop-in replacement
# 1 loop, best of 3: 8.98 s per loop

并且有一亿个元素(最后两种方法最终填满了所有内存,因此被省略):

n = 10 ** 8
k = n // 3
np.random.seed(0)
a = np.random.randint(0, k * n, n)
b = np.random.randint(0, k * n, n)


%timeit ainb = np.isin(a, b); a[ainb]
# 1 loop, best of 3: 1min 4s per loop
%timeit ainb = is_in_set_nb(a, b); a[ainb]
# 1 loop, best of 3: 13.1 s per loop
%timeit ainb = is_in_set_pnb(a, b); a[ainb]
# 1 loop, best of 3: 11.4 s per loop
%timeit ainb = is_in_set(a, b); a[ainb]
# 1 loop, best of 3: 2min 5s per loop

为较小的输入添加更多计时,但所有长度组合为ab

funcs = np.isin, is_in_set_nb, is_in_set_pnb
sep = '    '
print(f'({"n=len(a)":>9s},{"m=len(b)":>9s})', end=sep)
for func in funcs:
    print(f'{func.__name__:15s}', end=sep)
print()
I, J = 7, 7
for i in range(I):
    for j in range(J):
        n = 10 ** i
        m = 10 ** j
        a = np.random.randint(0, m * n, n)
        b = np.random.randint(0, m * n, m)
        print(f'({n:9d},{m:9d})', end=sep)
        for func in funcs:
            result = %timeit -q -o func(a, b)
            print(f'{result.best * 1e3:12.3f} ms', end=sep)
        print()
( n=len(a), m=len(b))    isin               is_in_set_nb       is_in_set_pnb      
(        1,        1)           0.011 ms           0.001 ms           0.047 ms    
(        1,       10)           0.048 ms           0.001 ms           0.023 ms    
(        1,      100)           0.050 ms           0.002 ms           0.027 ms    
(        1,     1000)           0.102 ms           0.007 ms           0.041 ms    
(        1,    10000)           0.766 ms           1.028 ms           1.122 ms    
(        1,   100000)           9.717 ms           3.426 ms           3.356 ms    
(        1,  1000000)         105.154 ms          43.642 ms          40.734 ms    
(       10,        1)           0.010 ms           0.001 ms           0.023 ms    
(       10,       10)           0.030 ms           0.001 ms           0.023 ms    
(       10,      100)           0.053 ms           0.002 ms           0.027 ms    
(       10,     1000)           0.100 ms           0.007 ms           0.055 ms    
(       10,    10000)           0.961 ms           1.031 ms           1.154 ms    
(       10,   100000)           9.772 ms           3.595 ms           3.761 ms    
(       10,  1000000)         105.802 ms          54.260 ms          50.265 ms    
(      100,        1)           0.010 ms           0.001 ms           0.024 ms    
(      100,       10)           0.030 ms           0.002 ms           0.025 ms    
(      100,      100)           0.054 ms           0.002 ms           0.026 ms    
(      100,     1000)           0.105 ms           0.008 ms           0.045 ms    
(      100,    10000)           0.751 ms           1.076 ms           1.158 ms    
(      100,   100000)           9.824 ms           3.253 ms           3.329 ms    
(      100,  1000000)         105.697 ms          57.993 ms          55.285 ms    
(     1000,        1)           0.012 ms           0.005 ms           0.028 ms    
(     1000,       10)           0.038 ms           0.006 ms           0.029 ms    
(     1000,      100)           0.119 ms           0.007 ms           0.033 ms    
(     1000,     1000)           0.180 ms           0.014 ms           0.063 ms    
(     1000,    10000)           0.821 ms           1.074 ms           1.169 ms    
(     1000,   100000)           9.920 ms           3.392 ms           3.532 ms    
(     1000,  1000000)         104.666 ms          57.845 ms          54.603 ms    
(    10000,        1)           0.020 ms           0.041 ms           0.092 ms    
(    10000,       10)           0.089 ms           0.088 ms           0.158 ms    
(    10000,      100)           0.967 ms           0.112 ms           0.182 ms    
(    10000,     1000)           1.017 ms           0.161 ms           0.249 ms    
(    10000,    10000)           1.633 ms           1.137 ms           1.283 ms    
(    10000,   100000)          10.754 ms           3.027 ms           3.302 ms    
(    10000,  1000000)         101.926 ms          48.062 ms          49.117 ms    
(   100000,        1)           0.071 ms           0.409 ms           0.455 ms    
(   100000,       10)           0.575 ms           0.916 ms           0.803 ms    
(   100000,      100)          16.304 ms           1.201 ms           0.940 ms    
(   100000,     1000)          15.185 ms           1.566 ms           1.181 ms    
(   100000,    10000)          15.914 ms           1.454 ms           1.252 ms    
(   100000,   100000)          23.719 ms           4.820 ms           4.313 ms    
(   100000,  1000000)         119.668 ms          56.863 ms          54.570 ms    
(  1000000,        1)           0.774 ms           4.347 ms           3.407 ms    
(  1000000,       10)           6.207 ms           8.793 ms           5.957 ms    
(  1000000,      100)         178.498 ms          13.104 ms           8.544 ms    
(  1000000,     1000)         169.022 ms          16.198 ms          10.283 ms    
(  1000000,    10000)         177.986 ms          13.243 ms           8.973 ms    
(  1000000,   100000)         177.989 ms          19.856 ms          13.898 ms    
(  1000000,  1000000)         283.207 ms          97.118 ms          84.332 ms 

这表明Numba和并行化对于较大的输入非常有益,而对于较小的输入则效率稍低。 然而,在上述大多数测试中,他们仍然优于np.isin()

相关问题 更多 >