numpy.where(condition)的输出不是数组,而是元组:为什么?

54 投票
3 回答
73236 浏览
提问于 2025-05-10 10:01

我正在尝试使用 numpy.where(condition[, x, y]) 这个函数。
numpy 的文档 中,我了解到如果只输入一个数组,它应该返回这个数组中非零元素的索引(也就是“真”的地方):

如果只给出条件,返回元组 condition.nonzero(),即条件为真的索引。

但是当我尝试时,它返回了一个包含两个元素的 元组,第一个是我想要的索引列表,第二个是一个空元素:

>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> np.where(array>4)
(array([4, 5, 6, 7, 8]),) # notice the comma before the last parenthesis

所以问题是:为什么会这样?这种行为有什么目的?在什么情况下这会有用呢?实际上,为了得到我想要的索引列表,我必须加上索引,比如 np.where(array>4)[0],这看起来有点... “丑陋”。


附录

我从一些回答中了解到,这实际上是一个只有一个元素的元组。不过我还是不明白为什么要这样输出。为了说明这并不理想,考虑以下错误(这也是我最初提问的原因):

>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> pippo = np.where(array>4)
>>> pippo + 1
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can only concatenate tuple (not "int") to tuple

所以你需要做一些索引才能访问实际的索引数组:

>>> pippo[0] + 1
array([5, 6, 7, 8, 9])

相关文章:

  • 暂无相关问题
暂无标签

3 个回答

2

只需要使用 np.asarray 这个函数就可以了。在你的情况下:

>>> import numpy as np
>>> array = np.array([1,2,3,4,5,6,7,8,9])
>>> pippo = np.asarray(np.where(array>4))
>>> pippo + 1
array([[5, 6, 7, 8, 9]])
12

简短的回答是:np.where 这个函数的输出是很一致的,不管你用的是几维的数组。

一个二维数组有两个索引,所以 np.where 的结果是一个包含两个相关索引的长度为2的元组。如果是三维数组,结果就是一个长度为3的元组;四维数组则是长度为4的元组;以此类推,对于N维数组,结果就是长度为N的元组。根据这个规则,在一维数组中,结果应该是一个长度为1的元组。

50

在Python中,(1)其实就是1。括号()可以随意加上,用来让数字和表达式看起来更清晰(比如(1+3)*3(1+3,)*3的区别)。所以,如果你想表示一个只有一个元素的元组,就要用(1,)(而且必须这样用)。

因此,

(array([4, 5, 6, 7, 8]),)

是一个只有一个元素的元组,这个元素是一个数组。

如果你对一个二维数组使用where,结果会是一个包含两个元素的元组。

where的结果可以直接用在索引的位置,比如:

a[where(a>0)]
a[a>0]

应该返回和

I,J = where(a>0)   # a is 2d
a[I,J]
a[(I,J)]

一样的结果。

或者用你的例子:

In [278]: a=np.array([1,2,3,4,5,6,7,8,9])
In [279]: np.where(a>4)
Out[279]: (array([4, 5, 6, 7, 8], dtype=int32),)  # tuple

In [280]: a[np.where(a>4)]
Out[280]: array([5, 6, 7, 8, 9])

In [281]: I=np.where(a>4)
In [282]: I
Out[282]: (array([4, 5, 6, 7, 8], dtype=int32),)
In [283]: a[I]
Out[283]: array([5, 6, 7, 8, 9])

In [286]: i, = np.where(a>4)   # note the , on LHS
In [287]: i
Out[287]: array([4, 5, 6, 7, 8], dtype=int32)  # not tuple
In [288]: a[i]
Out[288]: array([5, 6, 7, 8, 9])
In [289]: a[(i,)]
Out[289]: array([5, 6, 7, 8, 9])

======================

np.flatnonzero展示了如何正确地返回一个数组,而不管输入数组的维度。

In [299]: np.flatnonzero(a>4)
Out[299]: array([4, 5, 6, 7, 8], dtype=int32)
In [300]: np.flatnonzero(a>4)+10
Out[300]: array([14, 15, 16, 17, 18], dtype=int32)

它的文档上说:

这等同于ravel().nonzero()[0]

实际上,这就是这个函数的作用。

通过“展平”a,就解决了多维数组该怎么处理的问题。然后它从元组中取出结果,给你一个普通的数组。通过展平,就不需要为一维数组单独处理了。

===========================

@Divakar建议使用np.argwhere

In [303]: np.argwhere(a>4)
Out[303]: 
array([[4],
       [5],
       [6],
       [7],
       [8]], dtype=int32)

它的实现是np.transpose(np.where(a>4))

如果你不喜欢列向量,可以再转置一次

In [307]: np.argwhere(a>4).T
Out[307]: array([[4, 5, 6, 7, 8]], dtype=int32)

这样就变成了一个1xn的数组。

我们也可以把where放在array里:

In [311]: np.array(np.where(a>4))
Out[311]: array([[4, 5, 6, 7, 8]], dtype=int32)

有很多方法可以从where的元组中取出数组([0]i,=transposearray等等)。

撰写回答