Python 训练数组过大

0 投票
2 回答
5403 浏览
提问于 2025-04-18 18:13

我有超过3000万个对象需要用作我的训练数据。我的问题很简单:当我通过不断添加的方式创建我的训练数组时,到了某个临界点,列表变得太大,Python就崩溃了。有没有什么办法可以解决这个问题?我已经尝试了几个小时,但一直没有找到答案!

创建训练数组的代码示例

training_array = []
for ...:
    data = #load data from somewhere
    data_array = [x for x in data] #some large array, 2-3 million objects  
    for item in data_array:
        training_array.append(item.a + item.b)

过了一段时间,控制台上会打印出 "killed",然后Python就退出了。我该如何避免这种情况呢?

更具体的问题描述:

我正在尝试在一个非常非常大的数组上进行训练,但在创建这个数组时,Python就崩溃了。这个训练算法不能分批处理数据,而是需要一个完整的数组,这限制了我之前想到的解决办法。有没有其他方法可以创建这个数组,而不占用我所有的内存(如果这真的是问题所在)?

2 个回答

0

你可以做几件事:

  1. 分块处理数据 - 这样可以避免一次性处理一个超级大的数组,也能减少系统负担。

  2. 使用生成器来生成数据 - 生成器是“懒惰”计算的,这意味着它们不会一次性全部存在。每当你需要某个元素时,它才会被创建,而不是提前生成,这样就不会出现超级大的数组。如果你不熟悉生成器,可能会有点难理解,但网上有很多相关的资源可以帮助你。

针对你的具体问题,可以试试这个生成器:

def train_gen(data):
    data_gen = (x for x in data)  #The () here are important as it makes data_gen a generator as well, as opposed to a list
    for item in data_gen:
        yield item.a + item.b


data = #load data from somewhere
training_array = train_gen(data)

for item in training_array:
    #Iterates through training_array, producing one value, then discarding it such that only one item in training_array is in memory at a time
4
  • data 是一个 Python 列表吗?如果是的话,

    data_array = [x for x in data]
    

    就没必要这样做,因为这和说

    data_array = list(data)
    

    是一样的。这会复制一份 data,这样会占用两倍的内存,但不清楚这样做有什么用。

    另外,你可以使用 del data 来让 Python 回收不再需要的 data 占用的内存。

  • 另一方面,可能 data 是一个迭代器。如果是这样的话,你可以通过避免创建 Python 列表 data_array 来节省内存。特别是,你不需要 data_array 来定义 training_array。你可以用

    data_array = [x for x in data] #some large array, 2-3 million objects  
    for item in data_array:
        training_array.append(item.a + item.b)
    

    替换为列表推导式

    training_array = [x.a + x.b for x in data]
    
  • 如果你在使用 NumPy,并且最终想让 training_array 成为一个 NumPy 数组,那么你可以通过避免创建中间的 Python 列表 training_array 来节省更多内存。你可以直接从 data 定义 NumPy 数组 training_data

    training_array = np.fromiter((x.a + x.b for x in data),
                                 dtype=...)
    

    注意 (x.a + x.b for x in data) 是一个生成器表达式,这样就避免了如果使用列表推导式所需的更大内存。

    如果你知道 data 的长度,在调用 np.fromiter 时添加 count=... 会加快性能,因为这可以让 NumPy 预先分配合适的内存给最终的数组。

    你还需要指定正确的数据类型。如果 training_array 中的值是浮点数,你可以通过指定一个更小的 itemsize 的数据类型来节省内存(虽然会牺牲一些精度)。例如,dtype='float32' 会用 4 字节(即 32 位)来存储数组中的每个浮点数。通常 NumPy 使用 float64,也就是 8 字节的浮点数。所以你可以通过使用更小的数据类型来创建一个更小的数组(从而节省内存)。

  • 如果你仍然内存不足,你可以使用np.memmap 来创建一个基于文件的数组,而不是基于内存的数组。其他类似的选择包括使用h5pypytables 来创建 hdf5 文件。

撰写回答