在numpy中按维度连接
我有一个叫做 x
的东西。
[[[ 1 2 3]
[ 4 5 6]]
[[ 7 8 9]
[10 11 12]]] # shape (2,2,3)
我想要做的是
[[ 1 2 3 4 5 6]
[ 7 8 9 10 11 12]] # shape (2,6)
也就是说,我想把中间那一维的所有项目连接起来。
在这个特定的情况下,我可以用下面的方法得到这个结果:
x.reshape(2, 2*3)
或者更抽象一点说
x.reshape(x.shape[0], x.shape[1]*x.shape[2])
有没有一种简洁、符合 numpy 风格的方法,可以对任意维度的 x
得到这个结果,最好是我不需要自己去算索引?
我一直在尝试使用 concatenate
函数,但没有成功。
1 个回答
4
如果你只关心第一个维度的大小,可以使用
x.reshape(x.shape[0], -1)
这里的 -1
表示这个维度的大小会自动计算出来。对于更高维度的数组也是适用的,只要在新维度的元组中不超过一个 -1
。
你也可以通过直接给数组的 shape
属性赋值来做到这一点:
x.shape = (x.shape[0], -1)
使用 x.reshape(...)
和直接赋值给 x.shape
的主要区别在于,前者如果不能在不改变底层内存的情况下改变 x
的形状,可能会创建一个副本(比如说,如果 x
是不连续的),而后者则不会生成副本,而是会抛出一个 AttributeError
错误。