提取数组每行的非零元素
我正在使用 Python 3.9 和 NumPy 1.22。
假设我有一个 3x3 的矩阵:
import numpy as np
x = np.array([[10, 40, 0],
[0, 40, 90],
[10, 0, 90]])
这个矩阵里的所有元素都是大于等于 0 的整数。
每一行都有正好 2 个非零的整数。
我想把这些非零的整数提取出来,形成一个 3x2 的矩阵 y
,这样:
y = np.array([[10, 40],
[40, 90],
[10, 90]])
我觉得可以用 numpy.apply_along_axis
、numpy.squeeze
和/或 numpy.where
来实现,但我好像缺少了什么。
2 个回答
0
这里有一个解决方案,但它需要知道非零元素的确切数量(在这个例子中是2个)
x = np.array([[10,40,0],[0,40,90],[10,0,90]]) # exactly 2 non-zero per row
idx = np.where(x>0)
idx
Out[90]:
(array([0, 0, 1, 1, 2, 2], dtype=int64),
array([0, 1, 1, 2, 0, 2], dtype=int64))
y = x[idx].reshape((3,2))
y
Out[92]:
array([[10, 40],
[40, 90],
[10, 90]])
1
因为你知道每一行都有相同数量的零,所以你可以放心地把这些零去掉,然后用reshape
来重新调整数据的形状:
x = np.array([[10,40,0],[0,40,90],[10,0,90]])
out = x[x!=0].reshape(len(x), -1)
输出结果:
array([[10, 40],
[40, 90],
[10, 90]])
如果你想玩得开心一点,假如每一行的零的数量不一样,你可以把它们移动到最后面,然后切割一下,保留最少数量的零:
x = np.array([[10,40,0],[0,40,90],[0,0,90]])
# array([[10, 40, 0],
# [ 0, 40, 90],
# [ 0, 0, 90]])
m = x!=0
out = np.take_along_axis(x, np.argsort(~m, axis=1),
axis=1)[:, :m.sum(axis=1).max()]
# array([[10, 40],
# [40, 90],
# [90, 0]])