How do I make a generator thread-safe?

Posted 2024-05-08 16:23:19


I have a generator that looks like this:

def data_generator(data_file, index_list,....):
    orig_index_list = index_list
    while True:
        x_list = list()
        y_list = list()
        if patch_shape:
            index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
                                                 patch_overlap, patch_start_offset,pred_specific=pred_specific)
        else:
            index_list = copy.copy(orig_index_list)

        while len(index_list) > 0:
            index = index_list.pop()
            add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     skip_blank=skip_blank, permute=permute)
            if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
                x_list = list()
                y_list = list()

My dataset is 55 GB and stored as an .h5 file (data.h5). Reading the data is very slow: one epoch takes 7000 seconds, and after about 6 epochs I get a segmentation fault.

I thought that setting use_multiprocessing = False and workers > 1 would speed up reading the data:

model.fit(use_multiprocessing=False, workers=8)

But when I do that, I get the following error:

RuntimeError: Your generator is NOT thread-safe. Keras requires a thread-safe generator when use_multiprocessing=False, workers > 1.

Is there a way to make my generator thread-safe? Or is there another efficient way to generate this data?
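For context, a minimal sketch of why a plain generator is not thread-safe: CPython refuses to resume a generator that is already running and raises ValueError("generator already executing"). Two threads calling next() on the same generator at the same moment can hit exactly this, which is the situation the Keras check guards against. The example below reproduces the error deterministically in a single thread; self_consuming is a hypothetical name used only for illustration.

```python
def self_consuming():
    # Hypothetical example: the generator calls next() on itself while it is
    # still running, just as a second thread would if it called next() on a
    # shared generator that another thread is currently advancing.
    yield next(gen)


gen = self_consuming()

try:
    next(gen)
except ValueError as exc:
    message = str(exc)

print(message)  # generator already executing
```

In the multi-threaded case the failure is timing-dependent rather than deterministic, which is why the fix is to serialize access with a lock.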


1 answer
Anonymous user
#1 · posted 2024-05-08 16:23:19

I believe the LockedIterator class I referenced in my comment above is incorrect; it should be written as in the following example:

import threading

class LockedIterator(object):
    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = iter(it)

    def __iter__(self): return self

    def __next__(self):
        with self.lock:
            return next(self.it)

def gen():
    for x in range(10):
        yield x

new_gen = LockedIterator(gen())

def worker(g):
    for x in g:
        print(x, flush=True)

t1 = threading.Thread(target=worker, args=(new_gen,))
t2 = threading.Thread(target=worker, args=(new_gen,))
t1.start()
t2.start()
t1.join()
t2.join()

Prints (output from one run; output from the two threads can interleave, as with 2 and 3 here):

0
1
23

4
5
6
7
8
9

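If you want to guarantee that each value is printed on its own line, you also need to pass a threading.Lock instance to each thread and make the print call while holding that lock, so that printing is serialized.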
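A minimal sketch of the serialized-print idea described above. It reuses the LockedIterator from the answer and, as an addition not in the original answer, wraps the generator function with a small decorator (threadsafe_generator, a hypothetical helper) so that every call returns a LockedIterator; the same decorator could in principle be applied to data_generator. Each value appears on its own line, but the order of values across the two threads may still vary from run to run.

```python
import functools
import threading


class LockedIterator:
    """Wraps an iterator so that only one thread at a time can call next()."""

    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)


def threadsafe_generator(func):
    """Hypothetical helper: each call of the generator function returns a LockedIterator."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return LockedIterator(func(*args, **kwargs))
    return wrapper


@threadsafe_generator
def gen():
    for x in range(10):
        yield x


new_gen = gen()
print_lock = threading.Lock()  # shared by all workers to serialize output
seen = []  # every value consumed, recorded while holding print_lock


def worker(g, lock):
    for x in g:
        # Only one thread at a time may print, so lines never interleave.
        with lock:
            print(x, flush=True)
            seen.append(x)


threads = [threading.Thread(target=worker, args=(new_gen, print_lock)) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The iterator lock guarantees each value is handed to exactly one thread; the separate print lock only affects how the output is written, not which thread gets which value.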
