python - 不复制数据而重复numpy数组

8 投票
1 回答
1296 浏览
提问于 2025-04-18 06:45

这个问题之前有人问过,但之前的解决办法只适用于一维和二维数组,我需要一个更通用的答案。

我想知道怎么创建一个重复的数组,而不需要复制数据。这对我来说是个很实用的功能,因为这样可以在不占用太多内存的情况下加速Python的操作。

更具体来说,我有一个(y,x)的数组,我想把它重复多次,变成一个(z,y,x)的数组。我可以用numpy.tile(array, (nz,1,1))来做到这一点,但这样会让我内存不够用。我的具体情况是x=1500,y=2000,z=700。

1 个回答

5

一个简单的技巧是使用 np.broadcast_arrays 来将你的 (x, y) 与一个在第一维度上长度为 z 的向量进行广播。

import numpy as np

M = np.arange(1500*2000).reshape(1500, 2000)
z = np.zeros(700)

# broadcasting over the first dimension
_, M_broadcast = np.broadcast_arrays(z[:, None, None], M[None, ...])

print M_broadcast.shape, M_broadcast.flags.owndata
# (700, 1500, 2000), False

为了将 stride_tricks 方法推广到多维数组,你只需要为输出数组的每个维度提供形状和步长信息:

M_strided = np.lib.stride_tricks.as_strided(
                M,                              # input array
                (700, M.shape[0], M.shape[1]),  # output dimensions
                (0, M.strides[0], M.strides[1]) # stride length in bytes
            )

撰写回答