理解numpy的dstack函数
我对numpy的dstack
函数有点困惑。它的说明写得很简略,只说:
按深度顺序堆叠数组(沿着第三个轴)。
接受一系列数组,并沿着第三个轴将它们堆叠成一个单一的数组。重新构建通过
dsplit
分割的数组。这是将2D数组(图像)简单堆叠成一个3D数组以便处理的方法。
所以,要么我真的很笨,这个意思显而易见,要么我对“堆叠”、“顺序”、“深度方向”或“沿着一个轴”的这些术语有误解。不过,我觉得在vstack
和hstack
的上下文中,我是理解这些术语的。
让我们来看这个例子:
In [193]: a
Out[193]:
array([[0, 3],
[1, 4],
[2, 5]])
In [194]: b
Out[194]:
array([[ 6, 9],
[ 7, 10],
[ 8, 11]])
In [195]: dstack([a,b])
Out[195]:
array([[[ 0, 6],
[ 3, 9]],
[[ 1, 7],
[ 4, 10]],
[[ 2, 8],
[ 5, 11]]])
首先,a
和b
没有第三个轴,那我怎么能沿着“第三个轴”来堆叠它们呢?其次,假设a
和b
是2D图像的表示,为什么我最后得到的结果是三个2D数组,而不是两个“顺序”的2D数组呢?
4 个回答
因为你提到了“图像”,我觉得这个例子会很有用。如果你在用Keras训练一个二维卷积网络,并且输入是X,那么最好把X的维度保持为(#图像,图像的第一个维度,图像的第二个维度)。
image1 = np.array([[4,2],[5,5]])
image2 = np.array([[3,1],[6,7]])
image1 = image1.reshape(1,2,2)
image2 = image2.reshape(1,2,2)
X = np.stack((image1,image2),axis=1)
X
array([[[[4, 2],
[5, 5]],
[[3, 1],
[6, 7]]]])
np.shape(X)
X = X.reshape((2,2,2))
X
array([[[4, 2],
[5, 5]],
[[3, 1],
[6, 7]]])
X[0] # image 1
array([[4, 2],
[5, 5]])
X[1] # image 2
array([[3, 1],
[6, 7]])
我想试着用图形化的方式来解释这个问题(虽然接受的答案已经很清楚了,但我花了几秒钟才理解明白)。
如果我们把二维数组想象成一个列表的列表,其中第一个轴代表其中一个内部列表,第二个轴则表示该列表中的值,那么原问题中的数组可以这样表示:
a = [
[0, 3],
[1, 4],
[2, 5]
]
b = [
[6, 9],
[7, 10],
[8, 11]
]
# Shape of each array is [3,2]
现在,根据当前的文档,dstack
函数会添加一个第三个轴,这意味着每个数组的样子会变成这样:
a = [
[[0], [3]],
[[1], [4]],
[[2], [5]]
]
b = [
[[6], [9]],
[[7], [10]],
[[8], [11]]
]
# Shape of each array is [3,2,1]
然后,把这两个数组在第三个维度上堆叠,结果应该看起来像这样:
dstack([a,b]) = [
[[0, 6], [3, 9]],
[[1, 7], [4, 10]],
[[2, 8], [5, 11]]
]
# Shape of the combined array is [3,2,2]
希望这能帮到你。
假设我们有一个叫做 x == dstack([a, b])
的东西。这里的 x[:, :, 0]
和 a
是一模一样的,而 x[:, :, 1]
和 b
也是一模一样的。一般来说,当我们把二维数组用 dstack 叠起来时,输出的结果会是这样的:output[:, :, n]
和第 n 个输入数组是一样的。
如果我们叠的是三维数组,而不是二维数组:
x = numpy.zeros([2, 2, 3])
y = numpy.ones([2, 2, 4])
z = numpy.dstack([x, y])
那么 z[:, :, :3]
就和 x
一样,而 z[:, :, 3:7]
就和 y
一样。
正如你所看到的,我们需要在第三个轴上进行切片,才能恢复出传给 dstack
的输入。这就是 dstack
为什么会这样工作的原因。
要理解 np.vstack
、np.hstack
和 np.dstack
这几个函数的作用,最简单的方法就是看看它们输出数组的 .shape
属性。
以你提供的两个示例数组为例:
print(a.shape, b.shape)
# (3, 2) (3, 2)
np.vstack
是在第一个维度上进行拼接...print(np.vstack((a, b)).shape) # (6, 2)
np.hstack
是在第二个维度上进行拼接...print(np.hstack((a, b)).shape) # (3, 4)
而
np.dstack
则是在第三个维度上进行拼接。print(np.dstack((a, b)).shape) # (3, 2, 2)
因为 a
和 b
都是二维数组,所以 np.dstack
会通过插入一个大小为1的第三维来扩展它们。这就相当于用 np.newaxis
(或者说 None
)在第三维上进行索引,像这样:
print(a[:, :, np.newaxis].shape)
# (3, 2, 1)
如果 c = np.dstack((a, b))
,那么 c[:, :, 0] == a
和 c[:, :, 1] == b
。
你也可以用 np.concatenate
更明确地进行同样的操作,像这样:
print(np.concatenate((a[..., None], b[..., None]), axis=2).shape)
# (3, 2, 2)
* 使用 import *
将整个模块的内容导入到全局命名空间被认为是有很多原因的坏习惯。更好的做法是 import numpy as np
。