提取数组每行的非零元素

2 投票
2 回答
46 浏览
提问于 2025-04-13 19:46

我正在使用 Python 3.9 和 NumPy 1.22。

假设我有一个 3x3 的矩阵:

import numpy as np
x = np.array([[10, 40, 0],
              [0, 40, 90],
              [10, 0, 90]])

这个矩阵里的所有元素都是大于等于 0 的整数。

每一行都有正好 2 个非零的整数。

我想把这些非零的整数提取出来,形成一个 3x2 的矩阵 y,这样:

y = np.array([[10, 40],
              [40, 90],
              [10, 90]])

我觉得可以用 numpy.apply_along_axisnumpy.squeeze 和/或 numpy.where 来实现,但我好像缺少了什么。

2 个回答

0

这里有一个解决方案,但它需要知道非零元素的确切数量(在这个例子中是2个)

x = np.array([[10,40,0],[0,40,90],[10,0,90]])  # exactly 2 non-zero per row
idx = np.where(x>0)
idx
Out[90]: 
(array([0, 0, 1, 1, 2, 2], dtype=int64),
 array([0, 1, 1, 2, 0, 2], dtype=int64))
y = x[idx].reshape((3,2))
y
Out[92]: 
array([[10, 40],
       [40, 90],
       [10, 90]])
1

因为你知道每一行都有相同数量的零,所以你可以放心地把这些零去掉,然后用reshape来重新调整数据的形状:

x = np.array([[10,40,0],[0,40,90],[10,0,90]])

out = x[x!=0].reshape(len(x), -1)

输出结果:

array([[10, 40],
       [40, 90],
       [10, 90]])

如果你想玩得开心一点,假如每一行的零的数量不一样,你可以把它们移动到最后面,然后切割一下,保留最少数量的零:

x = np.array([[10,40,0],[0,40,90],[0,0,90]])
# array([[10, 40,  0],
#        [ 0, 40, 90],
#        [ 0,  0, 90]])

m = x!=0
out = np.take_along_axis(x, np.argsort(~m, axis=1),
                         axis=1)[:, :m.sum(axis=1).max()]
# array([[10, 40],
#        [40, 90],
#        [90,  0]])

撰写回答