找出numpy数组中非零前的零的数量

8 投票
6 回答
5029 浏览
提问于 2025-04-18 03:56

我有一个名为 A 的numpy数组。我想高效地计算在 A 中,非零元素之前有多少个零,尤其是在循环中。

比如说,如果 A = np.array([0,1,2]),那么 np.nonzero(A)[0][0] 会返回1,意思是第一个非零元素前面有1个零。但是如果 A = np.array([0,0,0]),这个方法就不管用了(在这种情况下,我想要的答案是3)。而且如果 A 非常大,而第一个非零元素又在前面,这样的做法似乎效率不高。

6 个回答

1

简单方法有什么问题:

def countLeadingZeros(x):
""" Count number of elements up to the first non-zero element, return that count """
    ctr = 0
    for k in x:
        if k == 0:
            ctr += 1
        else: #short circuit evaluation, we found a non-zero so return immediately
            return ctr
    return ctr #we get here in the case that x was all zeros

这个方法一旦找到一个非零的元素就会立即返回,所以在最坏的情况下,它的时间复杂度是O(n),也就是说如果数组很长,它可能会检查每一个元素。你可以把这个方法用C语言来写,这样可能会更快,但最好先测试一下,看看对于你正在处理的数组,这样做是否真的有必要。

1

我很惊讶为什么还没有人使用过 np.where

np.where(a)[0][0] if np.shape(np.where(a)[0])[0] != 0 else np.shape(a)[0] 这个方法可以解决问题。

>> a = np.array([0,1,2])
>> np.where(a)[0][0] if np.shape(np.where(a)[0])[0] != 0  else np.shape(a)[0]
... 1
>> a = np.array([0,0,0))
>> np.where(a)[0][0] if np.shape(np.where(a)[0])[0] != 0  else np.shape(a)[0]
... 3
>> a = np.array([1,2,3))
>> np.where(a)[0][0] if np.shape(np.where(a)[0])[0] != 0  else np.shape(a)[0]
... 0
4

在数组的末尾加上一个非零的数字,你仍然可以使用np.nonzero来得到你想要的结果。

A = np.array([0,1,2])
B = np.array([0,0,0])

np.min(np.nonzero(np.hstack((A, 1))))   # --> 1
np.min(np.nonzero(np.hstack((B, 1))))   # --> 3
4
i = np.argmax(A!=0)
if i==0 and np.all(A==0): i=len(A)

这个方法应该是最有效的解决方案,不需要额外的插件。而且它也很容易进行向量化处理,可以同时在多个方向上操作。

3

这里有一个迭代的Cython版本,如果你遇到严重的性能瓶颈,这可能是你最好的选择。

# saved as file count_leading_zeros.pyx
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.int
ctypedef np.int_t DTYPE_t

@cython.boundscheck(False)
def count_leading_zeros(np.ndarray[DTYPE_t, ndim=1] a):
    cdef int elements = a.size
    cdef int i = 0
    cdef int count = 0
    while i < elements:
        if a[i] == 0:
            count += 1
        else:
            return count
        i += 1
    return count

这个方法和@mtrw的回答类似,但在索引时速度更快。我对Cython的理解还不够深入,所以可能还有进一步改进的空间。

我用IPython快速测试了一下几个不同的方法,结果非常理想。

In [1]: import numpy as np

In [2]: import pyximport; pyximport.install()
Out[2]: (None, <pyximport.pyximport.PyxImporter at 0x53e9250>)

In [3]: import count_leading_zeros

In [4]: %paste
def count_leading_zeros_python(x):
    ctr = 0
    for k in x:
        if k == 0:
            ctr += 1
        else:
            return ctr
    return ctr
## -- End pasted text --
In [5]: a = np.zeros((10000000,), dtype=np.int)

In [6]: a[5] = 1

In [7]: 

In [7]: %timeit np.min(np.nonzero(np.hstack((a, 1))))
10 loops, best of 3: 91.1 ms per loop

In [8]: 

In [8]: %timeit np.where(a)[0][0] if np.shape(np.where(a)[0])[0] != 0  else np.shape(a)[0]
10 loops, best of 3: 107 ms per loop

In [9]: 

In [9]: %timeit count_leading_zeros_python(a)
100000 loops, best of 3: 3.87 µs per loop

In [10]: 

In [10]: %timeit count_leading_zeros.count_leading_zeros(a)
1000000 loops, best of 3: 489 ns per loop

不过,我只会在有证据(通过性能分析工具)证明这是个瓶颈的时候才使用这样的代码。很多东西看起来可能效率不高,但其实花时间去修复它们并不值得。

撰写回答