“model.fit_generator”中使用的参数“max_q_size”是什么?

2024-04-29 10:02:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我构建了一个简单的生成器,生成一个只有inputstargets列表中单个项的tuple(inputs, targets)。基本上,它是对数据集进行爬网,一次一个样本项。

我把发电机送入:

  model.fit_generator(my_generator(),
                      nb_epoch=10,
                      samples_per_epoch=1,
                      max_q_size=1  # defaults to 10
                      )

我明白了:

  • nb_epoch是训练批将运行的次数
  • samples_per_epoch是每个历元训练的样本数

但是max_q_size的用途是什么,为什么它会默认为10?我认为使用生成器的目的是将数据集批处理成合理的块,那么为什么要添加队列呢?


Tags: 数据列表sizemodelgeneratormax发电机样本
1条回答
网友
1楼 · 发布于 2024-04-29 10:02:22

这只是定义了内部训练队列的最大大小,该队列用于“预处理”生成器中的样本。它用于生成队列

def generator_queue(generator, max_q_size=10,
                    wait_time=0.05, nb_worker=1):
    '''Builds a threading queue out of a data generator.
    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
    '''
    q = queue.Queue()
    _stop = threading.Event()

    def data_generator_task():
        while not _stop.is_set():
            try:
                if q.qsize() < max_q_size:
                    try:
                        generator_output = next(generator)
                    except ValueError:
                        continue
                    q.put(generator_output)
                else:
                    time.sleep(wait_time)
            except Exception:
                _stop.set()
                raise

    generator_threads = [threading.Thread(target=data_generator_task)
                         for _ in range(nb_worker)]

    for thread in generator_threads:
        thread.daemon = True
        thread.start()

    return q, _stop

换句话说,您有一个线程直接从生成器将队列填充到给定的最大容量,而(例如)训练例程消耗其元素(有时等待完成)

 while samples_seen < samples_per_epoch:
     generator_output = None
     while not _stop.is_set():
         if not data_gen_queue.empty():
             generator_output = data_gen_queue.get()
             break
         else:
             time.sleep(wait_time)

为什么拖欠10英镑?没有什么特别的原因,就像大多数默认值一样——这很有意义,但是您也可以使用不同的值。

这样的结构表明,作者考虑了昂贵的数据生成器,这可能需要时间来执行。例如,考虑在生成器调用中通过网络下载数据—然后预处理一些下一批数据,并并行下载下一批数据是有意义的,这样可以提高效率并对网络错误等具有鲁棒性

相关问题 更多 >