理解numpy.r_()拼接的语法

2024-06-16 14:37:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我在numpy文档中为函数r_阅读了以下内容:

A string integer specifies which axis to stack multiple comma separated arrays along. A string of two comma-separated integers allows indication of the minimum number of dimensions to force each entry into as the second integer (the axis to concatenate along is still the first integer).

他们举了一个例子:

>>> np.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, dim>=2
array([[1, 2, 3],
       [4, 5, 6]])

我不明白,字符串'0,2'到底指示numpy做什么?

除了上面的链接,还有没有其他网站有更多关于这个功能的文档?


Tags: oftheto函数文档numpystringinteger
3条回答

字符串“0,2”告诉numpy沿0轴(第一个轴)连接,并将元素包装在足够多的括号中,以确保二维数组。考虑以下结果:

for axis in (0,1):
    for minDim in (1,2,3):
        print np.r_['{},{}'.format(axis, minDim), [1,2,30, 31], [4,5,6, 61], [7,8,90, 91], [10,11, 12, 13]], 'axis={}, minDim={}\n'.format(axis, minDim)

[ 1  2 30 31  4  5  6 61  7  8 90 91 10 11 12 13] axis=0, minDim=1

[[ 1  2 30 31]
 [ 4  5  6 61]
 [ 7  8 90 91]
 [10 11 12 13]] axis=0, minDim=2

[[[ 1  2 30 31]]

 [[ 4  5  6 61]]

 [[ 7  8 90 91]]

 [[10 11 12 13]]] axis=0, minDim=3

[ 1  2 30 31  4  5  6 61  7  8 90 91 10 11 12 13] axis=1, minDim=1

[[ 1  2 30 31  4  5  6 61  7  8 90 91 10 11 12 13]] axis=1, minDim=2

[[[ 1  2 30 31]
  [ 4  5  6 61]
  [ 7  8 90 91]
  [10 11 12 13]]] axis=1, minDim=3

您突出显示的段落是两个逗号分隔的整数语法,这是三个逗号分隔语法的特例。一旦您理解了三个逗号分隔的语法,两个逗号分隔的语法就就位了。

对于您的示例,等效的三个逗号分隔整数语法为:

np.r_['0,2,-1', [1,2,3], [4,5,6]]

为了提供更好的解释,我将上述内容改为:

np.r_['0,2,-1', [1,2,3], [[4,5,6]]]

以上分为两部分:

  1. 逗号分隔的整数字符串

  2. 两个逗号分隔的数组

逗号分隔的数组具有以下形状:

np.array([1,2,3]).shape
(3,)

np.array([[4,5,6]]).shape
(1, 3)

换句话说,第一个“数组”是“一维”,而第二个“数组”是“二维”。

首先,0,2,-1中的2意味着每个array都应该升级,这样它就必须至少是2-dimensional。因为第二个array已经2-dimensional了,所以它不受影响。然而,第一个array1-dimensional,为了使其成为2-dimensionalnp.r_,需要在其形状tuple中添加1,使其成为(1,3)(3,1)。这就是-1中的0,2,-1发挥作用的地方。它基本上决定了额外的1需要放在array的形状tuple中的什么位置。-1是默认值,并将1(或者1s,如果需要更多维度的话)放在形状tuple的前面(我在下面进一步解释原因)。这将第一个array's形状tuple转换为(1,3),这与第二个array's形状tuple相同。0,2,-1中的0意味着生成的数组需要沿“0”轴连接。

由于两个arrays现在都有一个tuple形状的(1,3)连接是可能的,因为如果将两个arrays中的连接轴(上面示例中的维度0的值都为1)放在一边,则剩余维度是相等的(在这种情况下,arrays中的剩余维度的值都是3)。如果不是这样,则会产生以下错误:

ValueError: all the input array dimensions except for the concatenation axis must match exactly

现在,如果将具有形状(1,3)的两个arrays连接起来,则生成的array将具有形状(1+1,3) == (2,3),因此:

np.r_['0,2,-1', [1,2,3], [[4,5,6]]].shape
(2, 3)

当使用0或正整数作为逗号分隔字符串中的第三个整数时,该整数确定升级形状tuple中每个array's形状元组的开始(仅适用于需要升级其维度的arrays)。例如,0,2,0意味着对于需要形状升级的arrays来说,array's原始形状tuple应该从升级形状tuple的维度0开始。对于形状为tuple(3,)array[1,2,3],将1放置在3之后。这将导致形状tuple等于(3,1),正如您所看到的,原始形状tuple(3,)从升级形状tuple的维度0开始。0,2,1意味着对于[1,2,3]array's形状tuple(3,)应该从升级的形状元组的维度1开始。这意味着1需要放置在维度0处。生成的形状元组将是(1,3)

当逗号分隔字符串中的第三个整数使用负数时,负号后面的整数将确定原始形状元组的结束位置。当原始形状元组是(3,)0,2,-1时,意味着原始形状元组应该结束于升级形状元组的最后一个维度,因此1将放置在升级形状元组的维度0处,而升级形状元组将是(1,3)。现在(3,)结束于升级形状元组的维度1,该维度也是升级形状元组的最后一个维度(原始数组是[1,2,3],升级数组是[[1,2,3]])。

np.r_['0,2', [1,2,3], [4,5,6]]

np.r_['0,2,-1', [1,2,3], [4,5,6]]

最后,这里是一个具有更多维度的示例:

np.r_['2,4,1',[[1,2],[4,5],[10,11]],[7,8,9]].shape
(1, 3, 3, 1)

逗号分隔的ar光线是:

有形状元组的[[1,2],[4,5],[10,11]]

有形状元组的[7,8,9]

两个arrays都需要升级到4-dimensional arrays。原始的array's形元组需要从维度1开始。

因此,对于第一个数组,形状变为(1,3,2,1),因为3,2从维度1开始,并且因为需要添加两个1s使其4-dimensional一个1放在原始形状元组之前,一个1放在之后。

使用相同的逻辑,第二个数组的形状元组变成(1,3,1,1)

现在需要使用维度2作为连接轴来连接这两个arrays。从每个数组的升级形状元组中删除维度2会导致这两个元组的元组都是(1,3,1)。由于得到的元组是相同的,所以可以将数组连接起来,并对连接的轴求和,以生成(1, 3, 2+1, 1) == (1, 3, 3, 1)

'n,m'告诉r_沿axis=n连接,并生成至少具有m维度的形状:

In [28]: np.r_['0,2', [1,2,3], [4,5,6]]
Out[28]: 
array([[1, 2, 3],
       [4, 5, 6]])

所以我们沿着轴=0连接,因此我们通常希望结果具有形状(6,),但是由于m=2,我们告诉r_形状必须至少是二维的。所以我们得到的是形状(2,3)

In [32]: np.r_['0,2', [1,2,3,], [4,5,6]].shape
Out[32]: (2, 3)

看看当我们增加m时会发生什么:

In [36]: np.r_['0,3', [1,2,3,], [4,5,6]].shape
Out[36]: (2, 1, 3)    # <- 3 dimensions

In [37]: np.r_['0,4', [1,2,3,], [4,5,6]].shape
Out[37]: (2, 1, 1, 3) # <- 4 dimensions

使用r_可以做的任何事情都可以使用一个更可读的数组构建函数来完成,例如np.concatenatenp.row_stacknp.column_stacknp.hstacknp.vstacknp.dstack,尽管它可能还需要调用reshape

即使调用Reforme,这些其他功能也可能更快:

In [38]: %timeit np.r_['0,4', [1,2,3,], [4,5,6]]
10000 loops, best of 3: 38 us per loop
In [43]: %timeit np.concatenate(([1,2,3,], [4,5,6])).reshape(2,1,1,3)
100000 loops, best of 3: 10.2 us per loop

相关问题 更多 >