numpy数组的累计argmax

2024-04-24 05:31:03 发布

您现在位置:Python中文网/ 问答频道 /正文

考虑数组a

np.random.seed([3,1415])
a = np.random.randint(0, 10, (10, 2))
a

array([[0, 2],
       [7, 3],
       [8, 7],
       [0, 6],
       [8, 6],
       [0, 2],
       [0, 4],
       [9, 7],
       [3, 2],
       [4, 3]])

什么是向量化的方法来获得累计的argmax?在

^{pr2}$

这是一种非矢量化的方法。在

请注意,随着窗口的扩展,argmax将应用于不断增长的窗口。在

pd.DataFrame(a).expanding().apply(np.argmax).astype(int).values

array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])

Tags: 方法dataframenprandom数组array矢量化seed
3条回答

我想不出一种方法可以轻松地将这两个列矢量化;但是如果列数相对于行数来说很小,那么这应该不是问题,对于该轴,for循环就足够了:

import numpy as np
import numpy_indexed as npi
a = np.random.randint(0, 10, (10))
max = np.maximum.accumulate(a)
idx = npi.indices(a, max)
print(idx)

下面是一个矢量化的纯numy解决方案,它的执行非常迅速:

def cumargmax(a):
    m = np.maximum.accumulate(a)
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1)
    x[1:] *= m[:-1] < m[1:]
    np.maximum.accumulate(x, axis=0, out=x)
    return x

然后我们有:

^{pr2}$

对具有数千到数百万个值的数组进行的一些快速测试表明,这比Python级别的循环(隐式或显式)快10-50倍。在

我想创建一个函数,计算1d数组的累计argmax,然后将其应用于所有列。代码如下:

import numpy as np

np.random.seed([3,1415])
a = np.random.randint(0, 10, (10, 2))

def cumargmax(v):
    uargmax = np.frompyfunc(lambda i, j: j if v[j] > v[i] else i, 2, 1)
    return uargmax.accumulate(np.arange(0, len(v)), 0, dtype=np.object).astype(v.dtype)

np.apply_along_axis(cumargmax, 0, a)

转换为np.object然后再转换回来的原因是nump1.9的一种解决方法,如generalized cumulative functions in NumPy/SciPy?所述

相关问题 更多 >