处理Python数组中的维度崩溃

2 投票
1 回答
2086 浏览
提问于 2025-04-18 11:15

我在使用NumPy时经常遇到一个错误,就是尝试访问数组的某个部分时失败了,因为数组的某个维度是单一的(也就是只有一个元素),所以这个维度被去掉了,无法进行索引。这在处理任意大小的数组时尤其麻烦。我想找到一种简单、通用的方法来避免这个错误。

这里有个例子:

import numpy as np
f = (lambda t, u, i=0: t[:,i]*u[::-1])
a = np.eye(3)
b = np.array([1,2,3])
f(a,b)
f(a[:,0],b[1])

第一次调用是正常的。第二次调用就出问题了:1) t 不能通过 [:,0] 来索引,因为它的形状是 (3,),2) u 完全无法索引,因为它是一个标量(也就是一个单一的数)。

我想到了一些解决办法:

1) 在函数 f 内部使用 np.atleast_1dnp.atleast_2d 等(可能还需要加一些条件来确保维度顺序正确),这样可以确保所有参数都有需要的维度。这种方法会让我不能使用简化的写法(lambda),而且可能需要写几行代码,我其实不想这么做。

2) 不要写 f(a[:,0],b[1]),而是用 f(a[:,[0]],b[[1]])。这样是可以的,但我总是得记得加上额外的括号。如果索引存储在一个变量里,你可能还不确定是否要加上这些额外的括号。例如:

idx = 1
f(a[:,[0]],b[[idx]])
idx = [2,0,1]
f(a[:,[0]],b[idx])

在这种情况下,你似乎得先对 idx 调用 np.atleast_1d,这可能比直接在函数里加 np.atleast_1d 更麻烦。

3) 在某些情况下,我可以选择不加索引。例如:

f = lambda t, u: t[0]*u
f(a,b)
f(a[:,0],b[0])

这样做是有效的,似乎是最简洁的解决方案,但并不是在所有情况下都适用(特别是你的维度必须一开始就正确)。

那么,有没有比以上方法更好的解决方案呢?

1 个回答

2

有很多方法可以避免这种情况。

首先,每当你用一个 slice(切片)来索引一个 np.ndarray 的某个维度,而不是用整数时,输出的维度数量会和输入的维度数量保持一致:

import numpy as np

x = np.arange(12).reshape(3, 4)
print x[:, 0].shape               # integer indexing
# (3,)

print x[:, 0:1].shape             # slice
# (3, 1)

这是我最喜欢的避免问题的方法,因为它可以很容易地从选择单个元素扩展到选择多个元素(比如 x[:, i:i+1]x[:, i:i+n])。

正如你已经提到的,你也可以通过使用任何整数序列来索引某个维度,从而避免维度丢失:

print x[:, [0]].shape             # list
# (3, 1)

print x[:, (0,)].shape            # tuple
# (3, 1)

print x[:, np.array((0,))].shape  # array
# (3, 1)

如果你选择继续使用整数索引,你可以随时通过 np.newaxis(或者说 None)来插入一个新的单一维度:

print x[:, 0][:, np.newaxis]
# (3, 1)

print x[:, 0][:, None]
# (3, 1)

或者你也可以手动调整它的形状到正确的大小(这里使用 -1 来自动推断第一个维度的大小):

print x[:, 0].reshape(-1, 1).shape
# (3, 1)

最后,你可以使用 np.matrix 而不是 np.ndarraynp.matrix 的行为更像 MATLAB 的矩阵,当你用整数索引时,单一维度会被保留:

y = np.matrix(x)
print y[:, 0].shape
# (3, 1)

不过,你要知道 np.matrixnp.ndarray 之间还有一些其他重要的区别,比如 * 运算符在数组上执行逐元素相乘,而在矩阵上执行矩阵乘法。在大多数情况下,最好还是使用 np.ndarrays

撰写回答