在numpy中创建索引数组消除double for循环

2024-06-07 13:11:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一些物理模拟代码,用python编写,使用numpy/scipy。分析代码显示38%的CPU时间都花在一个双嵌套的for循环中—这看起来太过分了,所以我一直在尝试减少它。在

循环的目标是创建一个索引数组,显示2D数组的元素等于1D数组中的哪些元素。在

indices[i,j] = where(1D_array == 2D_array[i,j])

例如,如果1D_array = [7.2, 2.5, 3.9]

^{pr2}$

我们应该有

indices = [[0, 1]
           [2, 0]]

我目前已将此实现为

for i in range(ni):
    for j in range(nj):
        out[i, j] = (1D_array - 2D_array[i, j]).argmin()

{I>不一定要处理浮点数。我知道一维数组中的每个数字都是唯一的,并且2D数组中的每个元素都有一个匹配项,所以这种方法给出了正确的结果。在

有没有办法消除双for循环?在

注意

我需要索引数组来执行以下操作:

f = complex_function(1D_array)
output = f[indices]

这比其他方法更快,因为2D数组的大小是NxN,而1D数组的大小是1xN,并且2D数组有许多重复值。如果有人可以提出一种不同的方法来获得相同的输出,而不必经过索引数组,那么这也可能是一个解决方案


Tags: 方法代码innumpy元素目标for时间
3条回答

为了消除两个Python for循环,您可以通过向数组中添加新的轴(使它们可以相互广播)来“一次性”完成所有的等式比较。在

请记住,这将生成一个包含len(arr1)*len(arr2)值的新数组。如果这是一个非常大的数字,这种方法可能是不可行的,这取决于你的记忆的局限性。否则,应该相当快:

>>> (arr1[:,np.newaxis] == arr2[:,np.newaxis]).argmax(axis=1)
array([[0, 1],
       [2, 0]], dtype=int32)

如果需要获取arr1中最接近匹配值的索引,请使用:

^{pr2}$

在纯Python中,您可以在O(N)时间内使用字典执行此操作,唯一的时间惩罚是涉及到的Python循环:

>>> arr1 = np.array([7.2, 2.5, 3.9])
>>> arr2 = np.array([[7.2, 2.5], [3.9, 7.2]])
>>> indices = dict(np.hstack((arr1[:, None], np.arange(3)[:, None])))
>>> np.fromiter((indices[item] for item in arr2.ravel()), dtype=arr2.dtype).reshape(arr2.shape)
array([[ 0.,  1.],
       [ 2.,  0.]])

其他人建议的dictionary方法可能有效,但它要求您提前知道目标数组(2d数组)中的每个元素在搜索数组(1d数组)中都有一个完全匹配的。即使这在原则上应该是真的,您仍然必须处理浮点精度问题,例如,尝试这个.1 * 3 == .3。在

另一种方法是使用numpy的searchsorted函数。searchsorted获取一个经过排序的1d搜索数组,然后任何traget数组都会为目标数组中的每个项找到搜索数组中最近的元素。我已经根据您的情况修改了这个answer,请看一下find_closest函数是如何工作的。在

import numpy as np

def find_closest(A, target):
    order = A.argsort()
    A = A[order]

    idx = A.searchsorted(target)
    idx = np.clip(idx, 1, len(A)-1)
    left = A[idx-1]
    right = A[idx]
    idx -= target - left < right - target
    return order[idx]

array1d = np.array([7.2, 2.5, 3.9])
array2d = np.array([[7.2, 2.5],
                    [3.9, 7.2]])

indices = find_closest(array1d, array2d)
print(indices)
# [[0 1]
#  [2 0]]

相关问题 更多 >

    热门问题