关于在Keras fit\u gen中使用python数据生成器的困惑

2024-05-23 19:11:18 发布

您现在位置:Python中文网/ 问答频道 /正文

一些问题和教程如下:

  1. Why is an iterable object not an iterator?
  2. Generator "TypeError: 'generator' object is not an iterator"

建议keras的数据生成器应该是一个类,其中包含\uuu iter\uuuuuuuuuuuu\uuuuu next\uuuuuuuuuuuuuu方法。你知道吗

其他一些教程如:

  1. https://keunwoochoi.wordpress.com/2017/08/24/tip-fit_generator-in-keras-how-to-parallelise-correctly/
  2. https://www.altumintelligence.com/articles/a/Time-Series-Prediction-Using-LSTM-Deep-Neural-Networks

将普通python函数与提供数据的yield语句一起使用。虽然我在上面的第二个教程之后成功地在LSTM网络中使用了收益率,但我无法在卷积网络中使用正常收益率函数,并且在fitèu generator中得到以下错误:

'method' object is not an iterator

我没有尝试过使用\uuuuunext\uuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuu。有人能帮我弄清楚什么时候使用哪种技术吗?一个函数“产生下一个样本”和一个类的下一个样本有什么区别?你知道吗

我的工作代码使用yield: https://github.com/KashyapCKotak/Multidimensional-Stock-Price-Prediction/blob/master/StockTF1_4Sequential.ipynb

我当前使用yield的数据生成器函数(编辑:在Daniel Möller建议的修复后工作):

def train_images_generator(self):
    for epoch in range(0, self.epochs):
      print("Current Epoch:",epoch)
      cnt = 0
      if epoch > 2000:
        learning_rate = 1e-5

      for ind in np.random.permutation(len(self.train_ids)):
        print("provided image with id:",ind)
        #get the input image and target/ground truth image based on ind
        raw = rawpy.imread(in_path)
        input_images = np.expand_dims(pack_raw(raw), axis=0) * ratio # pack the bayer image in 4 channels of RGBG

        gt_raw = rawpy.imread(gt_path)
        im = gt_raw.postprocess(use_camera_wb=True,
                      half_size=False,
                      no_auto_bright=True, output_bps=16)
        gt_images = np.expand_dims(np.float32(im / 65535.0),axis=0) # divide by 65535 to normalise (scale between 0 and 1)

        # crop

        H = input_images.shape[1] # get the image height (number of rows)
        W = input_images.shape[2] # get the image width (number of columns)

        xx = np.random.randint(0, W - ps) # get a random number in W-ps (W-512)
        yy = np.random.randint(0, H - ps) # get a random number in H-ps (H-512)
        input_patch = input_images[:, yy:yy + ps, xx:xx + ps, :]
        gt_patch = gt_images[:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]

        if np.random.randint(2) == 1:  # random flip for rows
          input_patch = np.flip(input_patch, axis=1)
          gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2) == 1:  # random flip for columns
          input_patch = np.flip(input_patch, axis=2)
          gt_patch = np.flip(gt_patch, axis=2)
        if np.random.randint(2) == 1:  # random transpose
          input_patch = np.transpose(input_patch, (0, 2, 1, 3))
          gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))\

        input_patch = np.minimum(input_patch, 1.0)

        yield (input_patch,gt_patch)

我如何使用它:

model.fit_generator(
  generator=data.train_images_generator(),
  steps_per_epoch=steps_per_epoch,
  epochs=epochs,
  callbacks=callbacks,
  max_queue_size=50
  #workers=0

()


Tags: inimagegtinputgetrawnprandom
1条回答
网友
1楼 · 发布于 2024-05-23 19:11:18

仔细看看'method'这个词,我发现您并没有“调用”您的生成器(您并没有创建它)。你知道吗

您只传递函数/方法。你知道吗

假设你有:

def generator(...):
    ...
    yield x, y

而不是像这样:

model.fit_generator(generator)

你应该这样做:

model.fit_generator(generator(...))

发生器或序列

使用生成器(带有yield的函数)和keras.utils.Sequence的函数有什么区别?你知道吗

当使用生成器时,训练将按照确切的循环顺序进行,并且不知道何时完成。所以。你知道吗

带发电机:

  • 无法洗牌批处理,因为它将始终遵循循环的顺序
  • 必须通知steps_per_epoch,因为Keras无法知道生成器何时完成(Keras的生成器必须是无限的)
  • 如果使用多处理,系统可能无法正确处理批处理,因为无法知道哪个进程将在其他进程之前启动或完成。你知道吗

Sequence

  • 你可以控制发电机的长度。Keras自动知道批的数量
  • 您可以控制批的索引,以便Keras可以洗牌批。你知道吗
  • 你可以按你想要的批次取多少次(你不必按顺序取批次)
  • 多处理可以使用索引来确保批处理最终不会混合在一起。你知道吗

相关问题 更多 >