在二维矩阵上使用Numpy的where()
我有一个这样的矩阵
t = np.array([[1,2,3,'foo'],
[2,3,4,'bar'],
[5,6,7,'hello'],
[8,9,1,'bar']])
我想找出哪些行包含字符串 'bar'
结果应该是一个一维数组
rows = np.where(t == 'bar')
应该给我这些索引 [0,3],然后进行广播操作:-
results = t[rows]
这样应该能得到正确的行
但是我不知道怎么在二维数组中实现这个功能。
2 个回答
31
你需要把数组切片,选择你想要索引的那一列:
rows = np.where(t[:,3] == 'bar')
result = t[rows]
这样会返回:
[[2,3,4,'bar'],
[8,9,1,'bar']]
31
对于一般情况,也就是说你的搜索字符串可以出现在任何一列,你可以这样做:
>>> rows, cols = np.where(t == 'bar')
>>> t[rows]
array([['2', '3', '4', 'bar'],
['8', '9', '1', 'bar']],
dtype='|S11')