Python 训练数组过大
我有超过3000万个对象需要用作我的训练数据。我的问题很简单:当我通过不断添加的方式创建我的训练数组时,到了某个临界点,列表变得太大,Python就崩溃了。有没有什么办法可以解决这个问题?我已经尝试了几个小时,但一直没有找到答案!
创建训练数组的代码示例
training_array = []
for ...:
data = #load data from somewhere
data_array = [x for x in data] #some large array, 2-3 million objects
for item in data_array:
training_array.append(item.a + item.b)
过了一段时间,控制台上会打印出 "killed"
,然后Python就退出了。我该如何避免这种情况呢?
更具体的问题描述:
我正在尝试在一个非常非常大的数组上进行训练,但在创建这个数组时,Python就崩溃了。这个训练算法不能分批处理数据,而是需要一个完整的数组,这限制了我之前想到的解决办法。有没有其他方法可以创建这个数组,而不占用我所有的内存(如果这真的是问题所在)?
2 个回答
你可以做几件事:
分块处理数据 - 这样可以避免一次性处理一个超级大的数组,也能减少系统负担。
使用生成器来生成数据 - 生成器是“懒惰”计算的,这意味着它们不会一次性全部存在。每当你需要某个元素时,它才会被创建,而不是提前生成,这样就不会出现超级大的数组。如果你不熟悉生成器,可能会有点难理解,但网上有很多相关的资源可以帮助你。
针对你的具体问题,可以试试这个生成器:
def train_gen(data):
data_gen = (x for x in data) #The () here are important as it makes data_gen a generator as well, as opposed to a list
for item in data_gen:
yield item.a + item.b
data = #load data from somewhere
training_array = train_gen(data)
for item in training_array:
#Iterates through training_array, producing one value, then discarding it such that only one item in training_array is in memory at a time
data
是一个 Python 列表吗?如果是的话,data_array = [x for x in data]
就没必要这样做,因为这和说
data_array = list(data)
是一样的。这会复制一份
data
,这样会占用两倍的内存,但不清楚这样做有什么用。另外,你可以使用
del data
来让 Python 回收不再需要的data
占用的内存。另一方面,可能
data
是一个迭代器。如果是这样的话,你可以通过避免创建 Python 列表data_array
来节省内存。特别是,你不需要data_array
来定义training_array
。你可以用data_array = [x for x in data] #some large array, 2-3 million objects for item in data_array: training_array.append(item.a + item.b)
替换为列表推导式
training_array = [x.a + x.b for x in data]
如果你在使用 NumPy,并且最终想让
training_array
成为一个 NumPy 数组,那么你可以通过避免创建中间的 Python 列表training_array
来节省更多内存。你可以直接从data
定义 NumPy 数组training_data
:training_array = np.fromiter((x.a + x.b for x in data), dtype=...)
注意
(x.a + x.b for x in data)
是一个生成器表达式,这样就避免了如果使用列表推导式所需的更大内存。如果你知道
data
的长度,在调用np.fromiter
时添加count=...
会加快性能,因为这可以让 NumPy 预先分配合适的内存给最终的数组。你还需要指定正确的数据类型。如果
training_array
中的值是浮点数,你可以通过指定一个更小的 itemsize 的数据类型来节省内存(虽然会牺牲一些精度)。例如,dtype='float32'
会用 4 字节(即 32 位)来存储数组中的每个浮点数。通常 NumPy 使用float64
,也就是 8 字节的浮点数。所以你可以通过使用更小的数据类型来创建一个更小的数组(从而节省内存)。如果你仍然内存不足,你可以使用np.memmap 来创建一个基于文件的数组,而不是基于内存的数组。其他类似的选择包括使用h5py 或pytables 来创建 hdf5 文件。