无复制连接Numpy数组

88 投票
6 回答
56191 浏览
提问于 2025-04-17 04:51

在Numpy中,我可以用 np.appendnp.concatenate 把两个数组拼接在一起:

>>> X = np.array([[1,2,3]])
>>> Y = np.array([[-1,-2,-3],[4,5,6]])
>>> Z = np.append(X, Y, axis=0)
>>> Z
array([[ 1,  2,  3],
       [-1, -2, -3],
       [ 4,  5,  6]])

不过,这些方法会把输入的数组复制一份:

>>> Z[0,:] = 0
>>> Z
array([[ 0,  0,  0],
       [-1, -2, -3],
       [ 4,  5,  6]])
>>> X
array([[1, 2, 3]])

有没有办法可以把两个数组拼接成一个 视图,也就是说不复制?这样做需要创建一个 np.ndarray 的子类吗?

6 个回答

2

我也遇到过同样的问题,最后我采取了相反的做法。在正常拼接(用复制的方式)之后,我把原来的数组重新指向拼接后的数组,这样它们就变成了拼接数组的视图。

import numpy as np

def concat_no_copy(arrays):
    """ Concats the arrays and returns the concatenated array 
    in addition to the original arrays as views of the concatenated one.

    Parameters:
    -----------
    arrays: list
        the list of arrays to concatenate
    """
    con = np.concatenate(arrays)

    viewarrays = []
    for i, arr in enumerate(arrays):
        arrnew = con[sum(len(a) for a in arrays[:i]):
                     sum(len(a) for a in arrays[:i + 1])]
        viewarrays.append(arrnew)
        assert all(arr == arrnew)

    # return the view arrays, replace the old ones with these
    return con, viewarrays

你可以通过以下方式来测试:

def test_concat_no_copy():
    arr1 = np.array([0, 1, 2, 3, 4])
    arr2 = np.array([5, 6, 7, 8, 9])
    arr3 = np.array([10, 11, 12, 13, 14])

    arraylist = [arr1, arr2, arr3]

    con, newarraylist = concat_no_copy(arraylist)

    assert all(con == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 
                                11, 12, 13, 14]))

    for old, new in zip(arraylist, newarraylist):
        assert all(old == new)
17

在你往数组里填数据之前,先把数组初始化一下。如果你想的话,可以分配比实际需要更多的空间,这样也不会占用更多的内存,因为numpy的工作方式就是这样的。

A = np.zeros(R,C)
A[row] = [data]

内存只有在你把数据放进数组时才会被使用。把两个数组合并成一个新数组,如果数据集很大,比如超过1GB,就可能永远也处理不完。

99

Numpy数组的内存必须是连续的。如果你单独分配了多个数组,它们在内存中会随机分散,这样就无法把它们表示成一个视图的Numpy数组。

如果你事先知道需要多少个数组,可以先分配一个大的数组,然后让每个小数组作为这个大数组的视图(比如通过切片来获取)。

撰写回答