理解numpy.r_()拼接的语法
我在numpy的文档中看到关于函数 r_ 的说明:
一个字符串整数可以指定多个用逗号分隔的数组要沿着哪个轴进行堆叠。如果是两个用逗号分隔的整数的字符串,则可以指明每个条目最少要强制进入的维度数,第二个整数表示要拼接的轴,而第一个整数仍然是拼接的轴。
他们还给了一个例子:
>>> np.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, dim>=2
array([[1, 2, 3],
[4, 5, 6]])
我不太明白,字符串 '0,2'
到底是让numpy做什么呢?
除了上面的链接,还有没有其他网站提供关于这个函数的更多文档?
3 个回答
字符串 '0,2' 告诉 numpy 在第一个轴(也就是轴 0)上进行连接,并且要用足够的括号把元素包起来,以确保形成一个二维数组。来看一下下面的结果:
for axis in (0,1):
for minDim in (1,2,3):
print np.r_['{},{}'.format(axis, minDim), [1,2,30, 31], [4,5,6, 61], [7,8,90, 91], [10,11, 12, 13]], 'axis={}, minDim={}\n'.format(axis, minDim)
[ 1 2 30 31 4 5 6 61 7 8 90 91 10 11 12 13] axis=0, minDim=1
[[ 1 2 30 31]
[ 4 5 6 61]
[ 7 8 90 91]
[10 11 12 13]] axis=0, minDim=2
[[[ 1 2 30 31]]
[[ 4 5 6 61]]
[[ 7 8 90 91]]
[[10 11 12 13]]] axis=0, minDim=3
[ 1 2 30 31 4 5 6 61 7 8 90 91 10 11 12 13] axis=1, minDim=1
[[ 1 2 30 31 4 5 6 61 7 8 90 91 10 11 12 13]] axis=1, minDim=2
[[[ 1 2 30 31]
[ 4 5 6 61]
[ 7 8 90 91]
[10 11 12 13]]] axis=1, minDim=3
你提到的那段话讲的是两个用逗号分隔的整数的语法,这其实是三组用逗号分隔的语法的一种特殊情况。一旦你理解了三组逗号分隔的语法,两个逗号分隔的语法就会变得容易理解。
对于你的例子,三组逗号分隔的整数语法可以写成:
np.r_['0,2,-1', [1,2,3], [4,5,6]]
为了更好地解释,我将上面的内容改成:
np.r_['0,2,-1', [1,2,3], [[4,5,6]]]
上面有两个部分:
一个用逗号分隔的整数字符串
两个用逗号分隔的数组
这两个数组的形状如下:
np.array([1,2,3]).shape
(3,)
np.array([[4,5,6]]).shape
(1, 3)
换句话说,第一个“数组”是一维的,而第二个“数组”是二维的。
首先,0,2,-1
中的2
表示每个array
都应该被升级,至少要变成二维的。因为第二个array
已经是二维的,所以不受影响。然而,第一个array
是一维的,为了把它变成二维的,np.r_
需要在它的形状tuple
中添加一个1,使其变成(1,3)
或(3,1)
。这就是0,2,-1
中的-1
的作用。它基本上决定了额外的1应该放在array
的形状tuple
的哪个位置。-1
是默认值,会把1
(如果需要更多维度就是1s
)放在形状tuple
的前面(我会在下面进一步解释)。这使得第一个array
的形状tuple
变成(1,3)
,这和第二个array
的形状tuple
是一样的。0
在0,2,-1
中表示结果数组需要沿着“0”轴进行拼接。
因为现在两个arrays
的形状tuple
都是(1,3)
,所以可以进行拼接。只要把拼接轴(在上面的例子中是维度0,值为1)排除,两个arrays
的剩余维度是相等的(在这个例子中,两个arrays
的剩余维度值都是3)。如果不是这样,就会出现以下错误:
ValueError: 所有输入数组的维度必须完全匹配,除了拼接轴。
现在,如果你拼接两个形状为(1,3)
的arrays
,结果的array
将会是(1+1,3) == (2,3)
,因此:
np.r_['0,2,-1', [1,2,3], [[4,5,6]]].shape
(2, 3)
当在逗号分隔的字符串中使用0
或正整数作为第三个整数时,这个整数决定了每个array
的形状tuple
在升级后的形状tuple
中的起始位置(仅对那些需要升级维度的arrays
有效)。例如0,2,0
意味着对于需要形状升级的arrays
,array
的原始形状tuple
应该从升级后的形状tuple
的维度0开始。对于形状为(3,)
的array
[1,2,3]
,1
会放在3
之后。这样就会得到形状tuple
为(3,1)
,可以看到原始形状tuple
(3,)
是从升级后的形状tuple
的维度0
开始的。0,2,1
意味着对于[1,2,3]
,array
的形状tuple
(3,)
应该从升级后的形状tuple
的维度1开始。这意味着1需要放在维度0。结果的形状tuple
将是(1,3)
。
当在逗号分隔的字符串中使用负数作为第三个整数时,负号后的整数决定了原始形状tuple
应该结束的位置。当原始形状tuple
为(3,)
时,0,2,-1
意味着原始形状tuple
应该在升级后的形状tuple
的最后一个维度结束,因此1会放在升级后的形状tuple
的维度0上,升级后的形状tuple
将是(1,3)
。现在(3,)
在升级后的形状tuple
的维度1结束,这也是升级后的形状tuple
的最后一个维度(原始数组是[1,2,3]
,升级后的数组是[[1,2,3]]
)。
np.r_['0,2', [1,2,3], [4,5,6]]
与以下内容相同:
np.r_['0,2,-1', [1,2,3], [4,5,6]]
最后,这里有一个更高维度的例子:
np.r_['2,4,1',[[1,2],[4,5],[10,11]],[7,8,9]].shape
(1, 3, 3, 1)
用逗号分隔的数组是:
[[1,2],[4,5],[10,11]]
,它的形状tuple
是(3,2)
[7,8,9]
,它的形状tuple
是(3,)
这两个arrays
都需要升级为四维数组。原始array
的形状tuple
需要从维度1开始。
因此,第一个数组的形状变为(1,3,2,1)
,因为3,2
从维度1开始,并且因为需要添加两个1来使其变为四维,所以一个1放在原始形状tuple
之前,一个1放在之后。
使用相同的逻辑,第二个数组的形状tuple
变为(1,3,1,1)
。
现在这两个arrays
需要使用维度2作为拼接轴进行拼接。去掉每个数组升级后的形状tuple
的维度2,结果是两个arrays
的形状tuple
都是(1,3,1)
。因为结果的tuple
是相同的,所以可以拼接这两个数组,拼接轴的维度相加得到(1, 3, 2+1, 1) == (1, 3, 3, 1)
。
'n,m'
是在告诉 r_
要沿着 axis=n
进行拼接,并且结果的形状至少要有 m
个维度:
In [28]: np.r_['0,2', [1,2,3], [4,5,6]]
Out[28]:
array([[1, 2, 3],
[4, 5, 6]])
所以我们是沿着轴0进行拼接,通常情况下我们会期待结果的形状是 (6,)
,但是因为 m=2
,我们在告诉 r_
结果的形状必须至少是二维的。因此我们得到的形状是 (2,3)
:
In [32]: np.r_['0,2', [1,2,3,], [4,5,6]].shape
Out[32]: (2, 3)
看看当我们增加 m
时会发生什么:
In [36]: np.r_['0,3', [1,2,3,], [4,5,6]].shape
Out[36]: (2, 1, 3) # <- 3 dimensions
In [37]: np.r_['0,4', [1,2,3,], [4,5,6]].shape
Out[37]: (2, 1, 1, 3) # <- 4 dimensions
用 r_
能做到的事情,也可以用一些更易读的数组构建函数来完成,比如 np.concatenate
、np.row_stack
、np.column_stack
、np.hstack
、np.vstack
或 np.dstack
,不过可能还需要调用 reshape
。
即使调用了 reshape
,那些其他函数可能会更快:
In [38]: %timeit np.r_['0,4', [1,2,3,], [4,5,6]]
10000 loops, best of 3: 38 us per loop
In [43]: %timeit np.concatenate(([1,2,3,], [4,5,6])).reshape(2,1,1,3)
100000 loops, best of 3: 10.2 us per loop