基于2D掩码数组的numpy 3D到2D转换
如果我有一个像这样的ndarray:
>>> a = np.arange(27).reshape(3,3,3)
>>> a
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
我知道可以通过使用 np.max(axis=...)
来获取某个轴上的最大值:
>>> a.max(axis=2)
array([[ 2, 5, 8],
[11, 14, 17],
[20, 23, 26]])
另外,我也可以获取沿着那个轴的索引,这些索引对应于最大值:
>>> indices = a.argmax(axis=2)
>>> indices
array([[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
我的问题是——给定数组 indices
和数组 a
,有没有简单的方法来重现 a.max(axis=2)
返回的数组?
这样做可能有效:
import itertools as it
import numpy as np
def apply_mask(field,indices):
data = np.empty(indices.shape)
#It seems highly likely that there is a more numpy-approved way to do this.
idx = [range(i) for i in indices.shape]
for idx_tup,zidx in zip(it.product(*idx),indices.flat):
data[idx_tup] = field[idx_tup+(zidx,)]
return data
但是,这看起来有点笨拙/效率低下。而且,这种方法也只适用于“最后一个”轴。我想知道有没有numpy的函数(或者一些神奇的numpy索引用法)可以让这个工作?简单的 a[:,:,a.argmax(axis=2)]
是不行的。
更新:
似乎下面的方法也有效(而且看起来更好):
import numpy as np
def apply_mask(field,indices):
data = np.empty(indices.shape)
for idx_tup,zidx in np.ndenumerate(indices):
data[idx_tup] = field[idx_tup+(zidx,)]
return data
我想这样做是因为我想根据一个数组中的数据提取索引(通常使用 argmax(axis=...)
),然后用这些索引从其他一堆(形状相同的)数组中提取数据。我也愿意尝试其他方法来实现这个目标(例如,使用布尔掩码数组)。不过,我喜欢使用这些“索引”数组带来的“安全感”。这样我可以确保得到正确数量的元素,来创建一个看起来像是3D场中的2D“切片”的新数组。
2 个回答
-1
我使用 index_at()
来创建完整的索引:
import numpy as np
def index_at(idx, shape, axis=-1):
if axis<0:
axis += len(shape)
shape = shape[:axis] + shape[axis+1:]
index = list(np.ix_(*[np.arange(n) for n in shape]))
index.insert(axis, idx)
return tuple(index)
a = np.random.randint(0, 10, (3, 4, 5))
axis = 1
idx = np.argmax(a, axis=axis)
print a[index_at(idx, a.shape, axis=axis)]
print np.max(a, axis=axis)
3
这里有一些神奇的numpy索引方法,可以满足你的需求,但不幸的是,这些代码看起来相当难懂。
def apply_mask(a, indices, axis):
magic_index = [np.arange(i) for i in indices.shape]
magic_index = np.ix_(*magic_index)
magic_index = magic_index[:axis] + (indices,) + magic_index[axis:]
return a[magic_index]
或者同样难懂的:
def apply_mask(a, indices, axis):
magic_index = np.ogrid[tuple(slice(i) for i in indices.shape)]
magic_index.insert(axis, indices)
return a[magic_index]