有没有快速的方法将numpy数组中的一个元素与其他元素比较?

5 投票
3 回答
2783 浏览
提问于 2025-04-17 18:42

我有一个数组,我想检查这个数组里有没有哪个元素大于或等于其他任何元素。我可以用两个循环来实现,但我的数组长度有一万或者更多,这样做会让程序变得很慢。有没有更快的方法呢?

[编辑] 我只需要检查当前元素后面的元素,如果有大于或等于的,我还需要知道那个元素的位置。

[编辑] 我想更清楚地解释我的问题,因为现在的解决方案不适合我需要的。首先,这里有一些代码

x=linspace(-10, 10, 10000)
t=linspace(0,5,10000)

u=np.exp(-x**2)

k=u*t+x

我把一个x数组放进高斯函数里,得到它的高度,然后根据这个高度,得出这个x值在空间中传播的速度,这个速度是通过k来找到的。我的问题是,我需要找出高斯函数什么时候变成双值函数(换句话说,就是什么时候会发生冲击)。如果我用argmax的方法,我总是会得到k中的最后一个值,因为它非常接近零,我需要的是在给我双值的那个元素之后的第一个值。

[编辑] 小例子

x=[0,1,2,3,4,5,6,7,8,9,10] #Input 
k=[0,1,2,3,4,5,6,5,4,10] #adjusted for speed

output I want
in this case, 5 is the first number that goes above a number that comes after it.
So I need to know the index of where 5 is located and possibly the index 
of the number that it is greater than

3 个回答

2

编辑
其实,处理一个包含10,000个项目的Python循环,比处理一个包含100,000,000个项目的数组要便宜得多。

In [14]: np.where(np.array([True if np.all(k[:j] <= k[j]) else
                            False for j in xrange(len(k))]) == 0)
Out[14]: (array([5129, 5130, 5131, ..., 6324, 6325, 6326]),)

In [15]: %timeit np.where(np.array([True if np.all(k[:j] <= k[j]) else
                                    False for j in xrange(len(k))]) == 0)
1 loops, best of 3: 201 ms per loop

在内存方面,这会比较耗费,但你可以通过广播来加速搜索。如果你这样做:

>>> k <= k[:, None]
array([[ True, False, False, ..., False, False, False],
       [ True,  True, False, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       ..., 
       [ True,  True,  True, ...,  True, False, False],
       [ True,  True,  True, ...,  True,  True, False],
       [ True,  True,  True, ...,  True,  True,  True]], dtype=bool)

返回的结果是一个布尔数组,其中位置 [i, j] 的值告诉你 k[j] 是否小于或等于 k[i]。你可以使用 np.cumprod,方法如下:

>>> np.cumprod(k <= k[:, None], axis=1)
array([[1, 0, 0, ..., 0, 0, 0],
       [1, 1, 0, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ..., 
       [1, 1, 1, ..., 1, 0, 0],
       [1, 1, 1, ..., 1, 1, 0],
       [1, 1, 1, ..., 1, 1, 1]])

在这里,位置 [i, j] 的值告诉你 k[j] 是否小于或等于所有在 k[:i] 中的项目。如果你取这个矩阵的对角线:

>>> np.cumprod(k <= k[:, None], axis=1)[np.diag_indices(k.shape[0])]
array([1, 1, 1, ..., 1, 1, 1])

位置 [i] 的值告诉你 k[i] 是否小于或等于它之前的所有项目。找到那个数组中值为零的位置:

>>> np.where(np.cumprod(k <= k[:, None],
...                     axis=1)[np.diag_indices(k.shape[0])] == 0)
(array([5129, 5130, 5131, ..., 6324, 6325, 6326]),)

你就能得到所有满足你想要条件的值的索引。

如果你只对第一个感兴趣:

>>> np.argmax(np.cumprod(k <= k[:, None],
...                      axis=1)[np.diag_indices(k.shape[0])] == 0)
5129

这不是一个轻松的操作,但如果你的内存足够大,可以容纳所有的布尔数组,那你就不会等太久:

In [3]: %timeit np.argmax(np.cumprod(k <= k[:, None],
                                     axis=1)[np.diag_indices(k.shape[0])] == 0)
1 loops, best of 3: 948 ms per loop
3

这是一个向量化的解决方案,比ecatmur的方法快大约25%。

np.where(k > np.min(k[np.where(np.diff(k) < 0)[0][0]:]))[0][0]

这是一个简单直接的方法:

next(i for i in np.arange(len(arr)) if arr[i:].argmin() != 0)
5

第一个比后面的值大的数,必然是局部最小值中的最小值:

k = np.array([0,1,2,3,4,5,6,5,4,10])
lm_i = np.where(np.diff(np.sign(np.diff(k))) > 0)[0] + 1
mlm = np.min(k[lm_i])
mlm_i = lm_i[np.argmin(k[lm_i])]

第一个比后面值大的数的索引,就是那个局部最小值之后的第一个索引:

i = np.where(k > mlm)[0][0]

解决方案的图

(忽略图表在切线处似乎穿过水平线的情况;那只是显示上的小问题。)

简单来说:

np.where(k > np.min(k[np.where(np.diff(np.sign(np.diff(k))) > 0)[0] + 1]))[0][0]

注意,这个方法大约比其他方法快1000倍,因为它完全是向量化的:

%timeit np.where(k > np.min(k[np.where(np.diff(np.sign(np.diff(k))) > 0)[0] + 1]))[0][0]
1000 loops, best of 3: 228 us per loop

撰写回答