在多维numpy数组中按索引更新

2 投票
1 回答
697 浏览
提问于 2025-04-18 06:49

我正在使用numpy来统计很多大数组中的值,并记录最大值出现的位置。

具体来说,假设我有一个叫做'counts'的数组:

data = numpy.array([[ 5, 10, 3],
                    [ 6, 9, 12],
                    [13, 3,  9],
                    [ 9, 3,  1],
                    ...
                    ])
counts = numpy.zeros(data.shape, dtype=numpy.int)

data会经常变化,但我希望'counts'能够反映每个位置上最大值出现的次数:

max_value_indices = numpy.argmax(data, axis=1)
# this is now [1, 2, 0, 0, ...] representing the positions of 10, 12, 13 and 9, respectively.

根据我对numpy中广播机制的理解,我应该能够这样做:

counts[max_value_indices] += 1

我期待的是这个数组能够被更新:

[[0, 1, 0],
 [0, 0, 1],
 [1, 0, 0],
 [1, 0, 0],
 ...
]

但实际上,这样做却让counts中的所有值都增加了,结果是:

[[1, 1, 1],
 [1, 1, 1],
 [1, 1, 1],
 [1, 1, 1],
 ...
]

我还想过,如果把max_value_indices转换成一个100x1的数组,也许会有效:

counts[max_value_indices[:,numpy.newaxis]] += 1

但这样做的结果只是更新了位置0、1和2的元素:

[[1, 1, 1],
 [1, 1, 1],
 [1, 1, 1],
 [0, 0, 0],
 ...
]

我也乐意把索引数组变成一个只有0和1的数组,然后每次加到counts数组中,但我不太确定该怎么构造这个数组。

1 个回答

2

你可以使用所谓的高级整数索引(也叫多维位置索引)来处理数据:

In [24]: counts[np.arange(data.shape[0]), 
                np.argmax(data, axis=1)] += 1

In [25]: counts
Out[25]: 
array([[0, 1, 0],
       [0, 0, 1],
       [1, 0, 0],
       [1, 0, 0]])

第一个数组,np.arange(data.shape[0]) 是用来指定行的。第二个数组,np.argmax(data, axis=1) 是用来指定列的。

撰写回答