我正在读一些深层次的学习代码。我对numpy数组的高级索引有问题。我正在测试的代码:
import numpy
x = numpy.arange(2 * 8 * 3 * 64).reshape((2, 8, 3, 64))
x.shape
p1 = numpy.arange(2)[:, None]
sd = numpy.ones(2 * 64, dtype=int).reshape((2, 64))
p4 = numpy.arange(128 // 2)[None, :]
y = x[p1, :, sd, p4]
y.shape
为什么y
的形状是(2, 64, 8)
?你知道吗
以下是上述代码的输出:
>>> x.shape
(2, 8, 3, 64)
>>> p1
array([[0], [1]])
>>> sd
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
>>> p4
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]])
>>> y.shape
(2, 64, 8)
我读到:https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
我认为这与广播有关:
x
形状是(2, 8, 3, 64)
。你知道吗
p1
很简单,它的array([[0], [1]]),
意思就是选择第一维度的ind0, 1
。双阵是用来广播的。你知道吗
p2
是:
,意思是选择第二维度中的所有8个元素。你知道吗
p3
是一个棘手的问题,它包含两个“列表”从维度3的3个元素中选择一个,因此产生的新的第3维度应该是1。你知道吗
p4
意味着它选择了第四维中的所有64个元素。你知道吗
所以我认为y.shape
应该是(2, 8, 1, 64)
。你知道吗
但正确的答案是(2, 64, 8)
。为什么?你知道吗
当我第一次在numpy中遇到花哨的索引时,我也遇到了同样的问题。简单的回答是没有什么诀窍:花哨的索引只是将元素选择到与索引形状相同的输出中。使用纯粹的花哨索引,输出数组的形状将与广播的索引数组(described here)相同。输出的形状与输入的形状几乎没有任何关系,除非你也加入一个规则的切片索引(described here)。你的情况是后者,这增加了混乱。你知道吗
让我们看看你的指数,看看发生了什么:
关于如何进行的具体文件如下:
强调我的
请记住,在上述两种情况下,花式索引部分的维度都是索引数组的维度,而不是正在索引的数组。你知道吗
那么,您应该期望看到的是广播维度是
p1
、sd
和p4
(2, 64
),其次是第二维度的大小x
(8
)。这确实是你得到的:相关问题 更多 >
编程相关推荐