python - 不复制数据而重复numpy数组
这个问题之前有人问过,但之前的解决办法只适用于一维和二维数组,我需要一个更通用的答案。
我想知道怎么创建一个重复的数组,而不需要复制数据。这对我来说是个很实用的功能,因为这样可以在不占用太多内存的情况下加速Python的操作。
更具体来说,我有一个(y,x)的数组,我想把它重复多次,变成一个(z,y,x)的数组。我可以用numpy.tile(array, (nz,1,1))来做到这一点,但这样会让我内存不够用。我的具体情况是x=1500,y=2000,z=700。
1 个回答
5
一个简单的技巧是使用 np.broadcast_arrays
来将你的 (x, y)
与一个在第一维度上长度为 z
的向量进行广播。
import numpy as np
M = np.arange(1500*2000).reshape(1500, 2000)
z = np.zeros(700)
# broadcasting over the first dimension
_, M_broadcast = np.broadcast_arrays(z[:, None, None], M[None, ...])
print M_broadcast.shape, M_broadcast.flags.owndata
# (700, 1500, 2000), False
为了将 stride_tricks
方法推广到多维数组,你只需要为输出数组的每个维度提供形状和步长信息:
M_strided = np.lib.stride_tricks.as_strided(
M, # input array
(700, M.shape[0], M.shape[1]), # output dimensions
(0, M.strides[0], M.strides[1]) # stride length in bytes
)