无复制连接Numpy数组
在Numpy中,我可以用 np.append
或 np.concatenate
把两个数组拼接在一起:
>>> X = np.array([[1,2,3]])
>>> Y = np.array([[-1,-2,-3],[4,5,6]])
>>> Z = np.append(X, Y, axis=0)
>>> Z
array([[ 1, 2, 3],
[-1, -2, -3],
[ 4, 5, 6]])
不过,这些方法会把输入的数组复制一份:
>>> Z[0,:] = 0
>>> Z
array([[ 0, 0, 0],
[-1, -2, -3],
[ 4, 5, 6]])
>>> X
array([[1, 2, 3]])
有没有办法可以把两个数组拼接成一个 视图,也就是说不复制?这样做需要创建一个 np.ndarray
的子类吗?
6 个回答
2
我也遇到过同样的问题,最后我采取了相反的做法。在正常拼接(用复制的方式)之后,我把原来的数组重新指向拼接后的数组,这样它们就变成了拼接数组的视图。
import numpy as np
def concat_no_copy(arrays):
""" Concats the arrays and returns the concatenated array
in addition to the original arrays as views of the concatenated one.
Parameters:
-----------
arrays: list
the list of arrays to concatenate
"""
con = np.concatenate(arrays)
viewarrays = []
for i, arr in enumerate(arrays):
arrnew = con[sum(len(a) for a in arrays[:i]):
sum(len(a) for a in arrays[:i + 1])]
viewarrays.append(arrnew)
assert all(arr == arrnew)
# return the view arrays, replace the old ones with these
return con, viewarrays
你可以通过以下方式来测试:
def test_concat_no_copy():
arr1 = np.array([0, 1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8, 9])
arr3 = np.array([10, 11, 12, 13, 14])
arraylist = [arr1, arr2, arr3]
con, newarraylist = concat_no_copy(arraylist)
assert all(con == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14]))
for old, new in zip(arraylist, newarraylist):
assert all(old == new)
17
在你往数组里填数据之前,先把数组初始化一下。如果你想的话,可以分配比实际需要更多的空间,这样也不会占用更多的内存,因为numpy的工作方式就是这样的。
A = np.zeros(R,C)
A[row] = [data]
内存只有在你把数据放进数组时才会被使用。把两个数组合并成一个新数组,如果数据集很大,比如超过1GB,就可能永远也处理不完。
99
Numpy数组的内存必须是连续的。如果你单独分配了多个数组,它们在内存中会随机分散,这样就无法把它们表示成一个视图的Numpy数组。
如果你事先知道需要多少个数组,可以先分配一个大的数组,然后让每个小数组作为这个大数组的视图(比如通过切片来获取)。