在numpy数组切片操作中,是否可以组合逻辑条件和限制条件

2024-04-27 04:02:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我有下面的代码,完全可以做我想要的,但是太慢了,因为它涉及不必要的具体化步骤:

### init
a = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])

### condition 1) element 0 has to be larger than 1
### condition 2) limit the output to 2 elements
b = a[a[:,0] > 1][:2]

问题是当我有一个大数组时,这个过程非常慢(假设我只想用条件2切掉一小块)。这是很容易做到的,但我还没有找到一种方法把它放在一条直线上。在

因此,有没有一种简洁的方法可以在一个一行代码中有效地完成这一任务?像这样:

^{pr2}$

谢谢你!在


Tags: theto方法代码initnp步骤element
3条回答

我想不出在直接numpy中更快的解决方案,但是您可能可以使用numba做得更好:

from numba import autojit

def filtfunc(a):
    idx = []
    for ii in range(a.shape[0]):
        if (a[ii, 0] > 1):
            idx.append(ii)
            if (len(idx) == 2):
                break
    return a[idx]

jit_filter = autojit(filtfunc)

以下是另外两个建议的解决方案供参考:

^{pr2}$

一些时间安排:

^{3}$

您可能可以将此代码加快一点,您当前的代码如下所示:

a = np.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]])

# Check your condition
mask = a[:, 0] > 1

# Copy those rows the array that satisfy the condition 
temp = a[mask]

# Take first two rows of temp
b = temp[:2]

我怀疑最昂贵的操作是中间的复制操作,您可以尝试这样做来避免它,方法是:

^{pr2}$

可能有一种更有效的方法来找到前两个真值,我没有考虑太多,但关键是先找到您想要的值,然后只复制这些行。在

如果n=2与{}相比非常小,那么使用这个小函数可能是值得的。其基本思想是计算一个足够大的掩码,以提供所需的最终行数。在这里我迭代地做。通常迭代很慢,但是如果迭代次数足够少,那么在别处节省的时间是值得的。在

def mask(a):
    return a[:,0]>1

def paul_filter1(a,n):
    # incremental w/ sum
    j = a.shape[0]
    for i in xrange(n,j+1):
        am = mask(a[:i,:])
        if np.sum(am)>=n:
            j = i
            break
    return a[am,:]

请注意,掩码am可以比它正在处理的维度短。它有效地用False填充其余部分。我还没检查这是否有记录。在

在这个小例子中,fooa[a[:,0]>1,:][:2,:]慢3倍。在

但是对于一个更大的数组,比如a2=np.tile(a,[1000,1]),使用foo的时间保持不变,但是“蛮力”不断变慢,因为它必须将掩码应用到更多行。当然,这些计时确实取决于a中期望行的位置。如果foo必须使用几乎所有的行,则不会有任何节省。在

编辑

为了解决Bi-Rico对重复的np.sum(即使是快速编译的代码),我们可以逐步构建where

^{pr2}$

对于小的n这甚至更快。在

更接近原始方法的方法是计算完整的遮罩,然后进行修剪。cumsum可用于查找最小长度。在

^{3}$

1000x1000随机整数数组(1:12)测试,times为(使用20而不是2,并调整掩码以使更多的行为假)

In [172]: timeit paul_filter4(a,20)
1000 loops, best of 3: 690 us per loop

In [173]: timeit paul_filter3(a,20)
1000 loops, best of 3: 1.22 ms per loop

In [175]: timeit paul_filter1(a,20)
1000 loops, best of 3: 994 us per loop

In [176]: timeit rico_filter(a,20)
1000 loops, best of 3: 668 us per loop

In [177]: timeit marco_filter(a,20)
10 loops, best of 3: 21 ms per loop  

rico_filter使用where是最快的,但是我的另一种使用cumsum的方法并不落后。3个增量过滤器的速度相似,大约是快速过滤器的一半。在

在这个生成并测试的a中,大多数行是True。这与marco's关注的极限条件是逻辑条件的一个小子集是一致的。在这些条件下,比里科担心paul_filter1可能爆炸是不现实的。在

如果我更改了测试参数,那么a的所有行都必须被测试(a[:,0]>11), 然后使用wherecumsum的过滤器所用的时间与原始过滤器一样长。增量过滤器速度较慢,为15倍或更多。但我第一次尝试使用np.sum是这种风格中最快的。在

相关问题 更多 >