在numpy中按维度连接

3 投票
1 回答
2008 浏览
提问于 2025-04-18 16:07

我有一个叫做 x 的东西。

[[[ 1  2  3]
  [ 4  5  6]]

 [[ 7  8  9]
  [10 11 12]]] # shape (2,2,3)

我想要做的是

[[ 1  2  3  4  5  6]
 [ 7  8  9 10 11 12]] # shape (2,6)

也就是说,我想把中间那一维的所有项目连接起来。

在这个特定的情况下,我可以用下面的方法得到这个结果:

x.reshape(2, 2*3)

或者更抽象一点说

x.reshape(x.shape[0], x.shape[1]*x.shape[2])

有没有一种简洁、符合 numpy 风格的方法,可以对任意维度的 x 得到这个结果,最好是我不需要自己去算索引?

我一直在尝试使用 concatenate 函数,但没有成功。

1 个回答

4

如果你只关心第一个维度的大小,可以使用

x.reshape(x.shape[0], -1)

这里的 -1 表示这个维度的大小会自动计算出来。对于更高维度的数组也是适用的,只要在新维度的元组中不超过一个 -1

你也可以通过直接给数组的 shape 属性赋值来做到这一点:

x.shape = (x.shape[0], -1)

使用 x.reshape(...) 和直接赋值给 x.shape 的主要区别在于,前者如果不能在不改变底层内存的情况下改变 x 的形状,可能会创建一个副本(比如说,如果 x 是不连续的),而后者则不会生成副本,而是会抛出一个 AttributeError 错误。

撰写回答