在二维矩阵上使用Numpy的where()

31 投票
2 回答
97517 浏览
提问于 2025-04-18 09:47

我有一个这样的矩阵

t = np.array([[1,2,3,'foo'],
 [2,3,4,'bar'],
 [5,6,7,'hello'],
 [8,9,1,'bar']])

我想找出哪些行包含字符串 'bar'

结果应该是一个一维数组

rows = np.where(t == 'bar')

应该给我这些索引 [0,3],然后进行广播操作:-

results = t[rows]

这样应该能得到正确的行

但是我不知道怎么在二维数组中实现这个功能。

2 个回答

31

你需要把数组切片,选择你想要索引的那一列:

rows = np.where(t[:,3] == 'bar')
result = t[rows]

这样会返回:

 [[2,3,4,'bar'],
  [8,9,1,'bar']]
31

对于一般情况,也就是说你的搜索字符串可以出现在任何一列,你可以这样做:

>>> rows, cols = np.where(t == 'bar')
>>> t[rows]
array([['2', '3', '4', 'bar'],
       ['8', '9', '1', 'bar']],
      dtype='|S11')

撰写回答