在numpy中对分区索引进行分组argmax/argmin

5 投票
2 回答
1619 浏览
提问于 2025-04-17 20:28

Numpy的ufunc有一个叫做reduceat的方法,可以在数组的连续部分上运行这些函数。这样,我就不需要写:

import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]

我可以写:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))

这两种写法都会返回切片a[0:4]a[4:5]a[5:10]中的最大值,结果是[8, 0, 9]

我想要一个类似的功能来执行argmax,注意我只想要每个部分的一个最大值的索引:对于上面的asplit_at,结果应该是[3, 4, 5](尽管在最后一组中,索引5和9都得到了最大值),就像下面的代码所返回的那样:

np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]

我会在下面发布一个可能的解决方案,但我希望看到一个不需要为组创建索引的向量化解决方案。

2 个回答

1

受到这个问题的启发,我在numpy_indexed这个包里增加了argmin和argmax的功能。下面是相应的测试代码。请注意,键的顺序可以是任意的(并且可以是npi支持的任何类型):

def test_argmin():
    keys   = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2]
    values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9]
    unique, amin = group_by(keys).argmin(values)
    npt.assert_equal(unique, [0, 1, 2])
    npt.assert_equal(amin,   [1, 4, 0])
1

这个解决方案涉及到为一组数据建立一个索引(在上面的例子中是 [0, 0, 0, 0, 1, 2, 2, 2, 2, 2])。

group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)

然后我们可以使用:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]

来得到 [3, 4, 5] 这个结果在 result 中。这里的 [::-1] 确保我们得到的是每组中的第一个最大值,而不是最后一个。

这个方法依赖于一个事实,就是在复杂赋值中,最后的索引决定了被赋的值。@seberg 提到过,这个方法不太可靠(而且有一个更安全的替代方案可以用 result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]] 来实现,这个方法需要对 len(maxima) ~ n_groups 个元素进行排序)。

撰写回答