Loading multiple npz files with multithreading

I have several .npz files. All of them share the same structure: each file contains exactly two variables, always under the same names. Right now I simply iterate over all the .npz files, retrieve the two values and append them to some global variables:

import numpy as np

# Let's assume there are 100 npz files
x_train = []
y_train = []
for npz_file_number in range(100):
    data = dict(np.load('{0:04d}.npz'.format(npz_file_number)))
    x_train.append(data['x'])
    y_train.append(data['y'])

This takes a while, and the bottleneck is the CPU. The order in which the x and y values get appended to x_train and y_train does not matter.

Is there any way to load several .npz files in multiple threads?

1 Answer

I was surprised by @Brent Washburne's comment and decided to try it out myself. I think the general problem is two-fold:

First, reading data is often IO-bound, so writing multithreaded code often does not yield big performance gains. Second, shared-memory parallelization is inherently difficult in Python due to the design of the language itself (most notably the GIL); the overhead is much larger than in native C.
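That said, np.load spends much of its time in file IO and zlib decompression, both of which release the GIL, so even a plain thread pool can be worth trying. Here is a minimal sketch using the standard library's ThreadPoolExecutor (not part of the original answer; the worker mirrors the read_x function defined below):

from concurrent.futures import ThreadPoolExecutor

import numpy as np

def read_x_threaded(path):
    # same worker as read_x below, run in threads instead of processes
    with np.load(path) as data:
        return data['x']

def threaded_read(files):
    # max_workers=4 is an assumption; tune it for your disk and core count
    with ThreadPoolExecutor(max_workers=4) as executor:
        return list(executor.map(read_x_threaded, files))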

But let's see what we can do.

# some imports
import numpy as np
import glob
from multiprocessing import Pool
import os

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    with np.load(path) as data:
        return data["x"]

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    with Pool() as pool:
        x_list = pool.map(read_x, files)
    return x_list

Okay, enough preparation. Let's look at the timings.

files = glob.glob(os.path.join(tmp_dir, '*.npz'))

%timeit serial_read(files)
# 1 loops, best of 3: 7.04 s per loop

%timeit parallel_read(files)
# 1 loops, best of 3: 3.56 s per loop

That actually looks like a decent speed-up. I am working on two real cores plus two hyper-threading cores.
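The helpers above only read x. Since the question needs both variables, one possible extension (a sketch, assuming each file stores x and y as in the question) returns a tuple per file and unzips the pooled results:

def read_xy(path):
    # load both arrays stored in a single npz file
    with np.load(path) as data:
        return data['x'], data['y']

def parallel_read_xy(files):
    with Pool() as pool:
        pairs = pool.map(read_xy, files)
    # unzip the list of (x, y) tuples into the two lists from the question
    x_train, y_train = map(list, zip(*pairs))
    return x_train, y_train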


To run and time everything at once, you can execute the following script:

from __future__ import print_function
from __future__ import division

# some imports
import numpy as np
import glob
import sys
import multiprocessing
import os
import timeit

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    data = dict(np.load(path))
    return data['x']

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    pool = multiprocessing.Pool(processes=4)
    x_list = pool.map(read_x, files)
    pool.close()
    pool.join()
    return x_list


files = glob.glob(os.path.join(tmp_dir, '*.npz'))
#files = files[0:5] # to test on a subset of the npz files

# Timing:
timeit_runs = 5

timer = timeit.Timer(lambda: serial_read(files))
print('serial_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# 1 loops, best of 3: 7.04 s per loop

timer = timeit.Timer(lambda: parallel_read(files))
print('parallel_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# 1 loops, best of 3: 3.56 s per loop

# Examples of use:
x = serial_read(files)
print('len(x): {0}'.format(len(x))) # len(x): 100
print('len(x[0]): {0}'.format(len(x[0]))) # len(x[0]): 10000
print('len(x[0][0]): {0}'.format(len(x[0][0]))) # len(x[0][0]): 50
print('x[0][0]: {0}'.format(x[0][0])) # x[0][0]: the 50 values of the first row
print('x[0].nbytes: {0} MB'.format(x[0].nbytes / 1e6)) # 4.0 MB
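Since the question says the append order does not matter, pool.imap_unordered is another option worth timing; it yields each result as soon as a worker finishes rather than preserving input order (a sketch under the same setup as the script above):

def parallel_read_unordered(files):
    # results arrive in completion order instead of input order,
    # which the question explicitly allows
    pool = multiprocessing.Pool(processes=4)
    try:
        x_list = list(pool.imap_unordered(read_x, files))
    finally:
        pool.close()
        pool.join()
    return x_list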
