处理Python数组中的维度崩溃
我在使用NumPy时经常遇到一个错误,就是尝试访问数组的某个部分时失败了,因为数组的某个维度是单一的(也就是只有一个元素),所以这个维度被去掉了,无法进行索引。这在处理任意大小的数组时尤其麻烦。我想找到一种简单、通用的方法来避免这个错误。
这里有个例子:
import numpy as np
f = (lambda t, u, i=0: t[:,i]*u[::-1])
a = np.eye(3)
b = np.array([1,2,3])
f(a,b)
f(a[:,0],b[1])
第一次调用是正常的。第二次调用就出问题了:1) t
不能通过 [:,0]
来索引,因为它的形状是 (3,)
,2) u
完全无法索引,因为它是一个标量(也就是一个单一的数)。
我想到了一些解决办法:
1) 在函数 f
内部使用 np.atleast_1d
和 np.atleast_2d
等(可能还需要加一些条件来确保维度顺序正确),这样可以确保所有参数都有需要的维度。这种方法会让我不能使用简化的写法(lambda),而且可能需要写几行代码,我其实不想这么做。
2) 不要写 f(a[:,0],b[1])
,而是用 f(a[:,[0]],b[[1]])
。这样是可以的,但我总是得记得加上额外的括号。如果索引存储在一个变量里,你可能还不确定是否要加上这些额外的括号。例如:
idx = 1
f(a[:,[0]],b[[idx]])
idx = [2,0,1]
f(a[:,[0]],b[idx])
在这种情况下,你似乎得先对 idx
调用 np.atleast_1d
,这可能比直接在函数里加 np.atleast_1d
更麻烦。
3) 在某些情况下,我可以选择不加索引。例如:
f = lambda t, u: t[0]*u
f(a,b)
f(a[:,0],b[0])
这样做是有效的,似乎是最简洁的解决方案,但并不是在所有情况下都适用(特别是你的维度必须一开始就正确)。
那么,有没有比以上方法更好的解决方案呢?
1 个回答
有很多方法可以避免这种情况。
首先,每当你用一个 slice
(切片)来索引一个 np.ndarray
的某个维度,而不是用整数时,输出的维度数量会和输入的维度数量保持一致:
import numpy as np
x = np.arange(12).reshape(3, 4)
print x[:, 0].shape # integer indexing
# (3,)
print x[:, 0:1].shape # slice
# (3, 1)
这是我最喜欢的避免问题的方法,因为它可以很容易地从选择单个元素扩展到选择多个元素(比如 x[:, i:i+1]
和 x[:, i:i+n]
)。
正如你已经提到的,你也可以通过使用任何整数序列来索引某个维度,从而避免维度丢失:
print x[:, [0]].shape # list
# (3, 1)
print x[:, (0,)].shape # tuple
# (3, 1)
print x[:, np.array((0,))].shape # array
# (3, 1)
如果你选择继续使用整数索引,你可以随时通过 np.newaxis
(或者说 None
)来插入一个新的单一维度:
print x[:, 0][:, np.newaxis]
# (3, 1)
print x[:, 0][:, None]
# (3, 1)
或者你也可以手动调整它的形状到正确的大小(这里使用 -1
来自动推断第一个维度的大小):
print x[:, 0].reshape(-1, 1).shape
# (3, 1)
最后,你可以使用 np.matrix
而不是 np.ndarray
。np.matrix
的行为更像 MATLAB 的矩阵,当你用整数索引时,单一维度会被保留:
y = np.matrix(x)
print y[:, 0].shape
# (3, 1)
不过,你要知道 np.matrix
和 np.ndarray
之间还有一些其他重要的区别,比如 *
运算符在数组上执行逐元素相乘,而在矩阵上执行矩阵乘法。在大多数情况下,最好还是使用 np.ndarrays
。