为什么np.genfromtxt()对于大型数据集,最初会占用大量内存?

2024-05-17 16:58:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个数据集,有450000列和450行-都是数值。我用np.genfromtxt()函数将数据集加载到NumPy数组中:

# The skip_header skips over the column names, which is the first row in the file
train = np.genfromtxt('train_data.csv', delimiter=',', skip_header=1)

train_labels = train[:, -1].astype(int)
train_features = train[:, :-1]

最初加载数据集时使用RAM-15GB以上的函数。但是,在该函数运行完毕后,它的RAM使用量仅为2-3GB。为什么np.genfromtxt()最初会占用这么多内存?在


Tags: the数据函数numpynptrain数组ram
2条回答

{在kasvd中提出了一个好的解决方案。这个答案中的iter_loadtxt()解决方案是我的问题的完美解决方案:

def iter_loadtxt(filename, delimiter=',', skiprows=0, dtype=float):
    def iter_func():
        with open(filename, 'r') as infile:
            for _ in range(skiprows):
                next(infile)
            for line in infile:
                line = line.rstrip().split(delimiter)
                for item in line:
                    yield dtype(item)
        iter_loadtxt.rowlength = len(line)

    data = np.fromiter(iter_func(), dtype=dtype)
    data = data.reshape((-1, iter_loadtxt.rowlength))
    return data

genfromtxt()占用这么多内存的原因是它在解析数据文件时没有将数据存储在高效的NumPy数组中,因此在NumPy解析我的大数据文件时内存使用过多。在

如果您提前知道数组的大小,则可以通过在解析时将每一行加载到目标数组中来节省时间和空间。在

例如:

In [173]: txt="""1,2,3,4,5,6,7,8,9,10
     ...: 2,3,4,5,6,7,8,9,10,11
     ...: 3,4,5,6,7,8,9,10,11,12
     ...: """

In [174]: np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding=None)
Out[174]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

具有更简单的解析功能:

^{pr2}$

out[i,:] = line.split(',')将字符串列表加载到数字数据类型数组中会强制进行转换,与np.array(line..., dtype=int)相同。在

In [179]: timeit np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding
     ...: =None)
266 µs ± 427 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [180]: timeit foo(txt.splitlines(),(3,10))
19.2 µs ± 169 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

更简单、直接的解析器要快得多。在

但是,如果我尝试loadtxtgenfromtxt使用的简化版本:

In [184]: def bar(txt):
     ...:     alist=[]
     ...:     for i,line in enumerate(txt):
     ...:        alist.append(line.split(','))
     ...:     return np.array(alist, dtype=int)
     ...: 
     ...: 
In [185]: bar(txt.splitlines())
Out[185]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [186]: timeit bar(txt.splitlines())
13 µs ± 20.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

对于这个小案子,它更快。genfromtxt必须有大量的解析开销。这是一个小样本,所以内存消耗并不重要。在


为了完整起见,loadtxt

In [187]: np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
Out[187]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [188]: timeit np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
103 µs ± 50.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

使用fromiter

In [206]: def g(txt):
     ...:     for row in txt:
     ...:         for item in row.split(','):
     ...:             yield item
In [209]: np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
Out[209]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [210]: timeit np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
12.3 µs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

相关问题 更多 >