在numpy中对分区索引进行分组argmax/argmin
Numpy的ufunc
有一个叫做reduceat
的方法,可以在数组的连续部分上运行这些函数。这样,我就不需要写:
import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]
我可以写:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
这两种写法都会返回切片a[0:4]
、a[4:5]
和a[5:10]
中的最大值,结果是[8, 0, 9]
。
我想要一个类似的功能来执行argmax
,注意我只想要每个部分的一个最大值的索引:对于上面的a
和split_at
,结果应该是[3, 4, 5]
(尽管在最后一组中,索引5和9都得到了最大值),就像下面的代码所返回的那样:
np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]
我会在下面发布一个可能的解决方案,但我希望看到一个不需要为组创建索引的向量化解决方案。
2 个回答
1
受到这个问题的启发,我在numpy_indexed这个包里增加了argmin和argmax的功能。下面是相应的测试代码。请注意,键的顺序可以是任意的(并且可以是npi支持的任何类型):
def test_argmin():
keys = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2]
values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9]
unique, amin = group_by(keys).argmin(values)
npt.assert_equal(unique, [0, 1, 2])
npt.assert_equal(amin, [1, 4, 0])
1
这个解决方案涉及到为一组数据建立一个索引(在上面的例子中是 [0, 0, 0, 0, 1, 2, 2, 2, 2, 2]
)。
group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)
然后我们可以使用:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]
来得到 [3, 4, 5]
这个结果在 result
中。这里的 [::-1]
确保我们得到的是每组中的第一个最大值,而不是最后一个。
这个方法依赖于一个事实,就是在复杂赋值中,最后的索引决定了被赋的值。@seberg 提到过,这个方法不太可靠(而且有一个更安全的替代方案可以用 result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]]
来实现,这个方法需要对 len(maxima) ~ n_groups
个元素进行排序)。