对NumPy数组进行懒惰评估的迭代

5 投票
3 回答
3803 浏览
提问于 2025-04-16 02:13

我有一个Python程序,它处理的NumPy数组比较大(几百兆),这些数组保存在磁盘上的pickle文件里(每个文件大约100MB)。当我想对数据进行查询时,我会通过pickle把整个数组加载到内存中,然后再进行查询(从Python程序的角度来看,整个数组都在内存里,即使操作系统可能会把它换出)。我这样做主要是因为我认为在NumPy数组上使用向量化操作会比用for循环逐个处理每个项目快得多。

我在一个有内存限制的网络服务器上运行这个程序,结果很快就碰到了内存的瓶颈。我有很多不同类型的查询需要在数据上运行,所以如果我写一段“分块”代码,从不同的pickle文件中加载数据的一部分,处理完后再继续下一个部分,这样会增加很多复杂性。能让这个“分块”过程对处理这些大数组的任何函数来说都是透明的,肯定是更好的选择。

看起来理想的解决方案是使用类似生成器的东西,定期从磁盘加载一块数据,然后一个一个地把数组的值传出来。这样可以大大减少程序所需的内存,而不需要对每个查询函数做额外的工作。这样做有可能吗?

3 个回答

2

看起来理想的解决方案应该是一个生成器,它可以定期从磁盘加载一块数据,然后一个一个地把数组里的值传出来。这样可以大大减少程序所需的内存,而且不需要对每个查询函数做额外的工作。这样做有可能吗?

可以,但不能把数组全部放在磁盘上的一个文件里,因为那种文件格式(pickle)并不支持“增量反序列化”。

你可以在同一个打开的文件中写入多个pickle文件,一个接一个(使用dump,而不是dumps),然后“懒惰评估器”在迭代时每次只需要使用pickle.load

示例代码(Python 3.1 -- 在2.x版本中,你需要用cPickle替代pickle,并且协议要用-1,等等;当然;-):

>>> import pickle
>>> lol = [range(i) for i in range(5)]
>>> fp = open('/tmp/bah.dat', 'wb')
>>> for subl in lol: pickle.dump(subl, fp)
... 
>>> fp.close()
>>> fp = open('/tmp/bah.dat', 'rb')
>>> def lazy(fp):
...   while True:
...     try: yield pickle.load(fp)
...     except EOFError: break
... 
>>> list(lazy(fp))
[range(0, 0), range(0, 1), range(0, 2), range(0, 3), range(0, 4)]
>>> fp.close()
4

NumPy的内存映射数据结构memmap)在这里可能是个不错的选择。

你可以从磁盘上的一个二进制文件中访问你的NumPy数组,而不需要一次性把整个文件都加载到内存中。

(注意,我认为NumPy的memmap对象和Python的memmap对象是不一样的——特别是,NumPy的memmap更像数组,而Python的memmap更像文件。)

这个方法的参数列表是:

A = NP.memmap(filename, dtype, mode, shape, order='C')

所有的参数都很简单明了(也就是说,它们的意思和NumPy其他地方用的一样),除了'order',这个参数指的是ndarray内存布局的顺序。我相信默认值是'C',而另一个选项是'F',代表Fortran——这两种选项分别表示行优先和列优先的顺序。

这两个方法是:

flush(这个方法会把你对数组所做的任何更改写入磁盘);以及

close(这个方法会把数据写入memmap数组,或者更准确地说,是写入一个类似数组的内存映射,指向存储在磁盘上的数据)

使用示例:

import numpy as NP
from tempfile import mkdtemp
import os.path as PH

my_data = NP.random.randint(10, 100, 10000).reshape(1000, 10)
my_data = NP.array(my_data, dtype="float")

fname = PH.join(mkdtemp(), 'tempfile.dat')

mm_obj = NP.memmap(fname, dtype="float32", mode="w+", shape=1000, 10)

# now write the data to the memmap array:
mm_obj[:] = data[:]

# reload the memmap:
mm_obj = NP.memmap(fname, dtype="float32", mode="r", shape=(1000, 10))

# verify that it's there!:
print(mm_obj[:20,:])
9

PyTables 是一个用来管理分层数据集的工具包。它的设计目的是为了帮助你解决这个问题。

撰写回答