<p>Use the following data generator:</p>
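<pre><code># Use tf.keras throughout instead of mixing standalone keras and tf.keras.
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator


class JoinedGen(tf.keras.utils.Sequence):
    """Pairs an image iterator with a mask iterator so each batch is (x, y)."""

    def __init__(self, input_gen, target_gen):
        super().__init__()
        assert len(input_gen) == len(target_gen)
        self.input_gen = input_gen
        self.target_gen = target_gen

    def __len__(self):
        # Number of batches per epoch.
        return len(self.input_gen)

    def __getitem__(self, i):
        x = self.input_gen[i]   # batch of images
        y = self.target_gen[i]  # matching batch of masks
        return x, y

    def on_epoch_end(self):
        # Each iterator reshuffles on its own; copy the mask order onto the
        # image iterator so images and masks stay aligned in the next epoch.
        self.input_gen.on_epoch_end()
        self.target_gen.on_epoch_end()
        self.input_gen.index_array = self.target_gen.index_array


# image_generator, mask_generator, model, mc and neptune_cbk are defined in the
# question's code (assumed: ImageDataGenerator(...).flow_from_directory(...)
# iterators with class_mode=None, built with the same seed).
train_generator = JoinedGen(image_generator, mask_generator)
model.fit(train_generator, epochs=500, verbose=1, callbacks=[mc, neptune_cbk])
</code></pre>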
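<p>The reason <code>on_epoch_end</code> copies <code>index_array</code> is that each iterator reshuffles on its own, so images and masks would drift apart after the first epoch. The plain-Python sketch below (no TensorFlow required) mimics that behaviour; <code>ArrayBatches</code> and <code>JoinedBatches</code> are hypothetical classes that only imitate the parts of a Keras iterator that <code>JoinedGen</code> touches, so treat this as an illustration rather than Keras's own implementation. Unlike the code above, <code>JoinedBatches</code> also syncs the order once in its constructor, so the two sources can even use different seeds.</p>

```python
import random


class ArrayBatches:
    """Hypothetical stand-in for a Keras image iterator: serves fixed-size
    batches from a list, in the order stored in ``index_array``."""

    def __init__(self, data, batch_size, seed=None):
        self.data = data
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        self.on_epoch_end()  # build the first shuffled order

    def __len__(self):
        # Batches per epoch; the last batch may be smaller.
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, i):
        idx = self.index_array[i * self.batch_size:(i + 1) * self.batch_size]
        return [self.data[j] for j in idx]

    def on_epoch_end(self):
        # Each source reshuffles with its own RNG, so orders diverge by default.
        self.index_array = list(range(len(self.data)))
        self.rng.shuffle(self.index_array)


class JoinedBatches:
    """Yields (inputs, targets) pairs that share one shuffled order."""

    def __init__(self, input_gen, target_gen):
        assert len(input_gen) == len(target_gen)
        self.input_gen = input_gen
        self.target_gen = target_gen
        self._sync()

    def _sync(self):
        # Reuse the target order for the inputs, as JoinedGen.on_epoch_end does.
        self.input_gen.index_array = self.target_gen.index_array

    def __len__(self):
        return len(self.input_gen)

    def __getitem__(self, i):
        return self.input_gen[i], self.target_gen[i]

    def on_epoch_end(self):
        self.input_gen.on_epoch_end()
        self.target_gen.on_epoch_end()
        self._sync()


if __name__ == "__main__":
    inputs = list(range(10))
    targets = [10 * v for v in inputs]  # target k belongs to input k
    joined = JoinedBatches(ArrayBatches(inputs, 4, seed=1),
                           ArrayBatches(targets, 4, seed=2))
    for epoch in range(3):
        for i in range(len(joined)):
            x, y = joined[i]
            assert y == [10 * v for v in x]  # pairs stay aligned
        joined.on_epoch_end()
    print("aligned for", len(joined), "batches x 3 epochs")
```

<p>Removing the <code>_sync()</code> call from <code>on_epoch_end</code> lets the two orders diverge, which is exactly the misalignment the answer's last line prevents.</p>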
<p>If you use <code>list</code> instead, every sample is generated before training starts, which will exhaust your memory.</p>
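<p>A toy sketch in plain Python makes the difference visible. The <code>LazyBatches</code> class is hypothetical, not a Keras API: indexing it builds only the requested batch, the way a <code>Sequence</code> is read, while <code>list(...)</code> builds and keeps every batch at once. The counter <code>built</code> records how many batches have been materialized.</p>

```python
class LazyBatches:
    """Hypothetical Sequence-style source: a batch is built only when indexed."""

    def __init__(self, num_batches, batch_size):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.built = 0  # batches materialized so far

    def __len__(self):
        return self.num_batches

    def __getitem__(self, i):
        if not 0 <= i < self.num_batches:
            raise IndexError(i)  # also ends plain iteration, e.g. list(...)
        self.built += 1
        # Placeholder for an expensive batch (loading/augmenting images).
        return [0.0] * self.batch_size


if __name__ == "__main__":
    lazy = LazyBatches(num_batches=1000, batch_size=32)
    first = lazy[0]          # Sequence-style access: only this batch is built
    print(lazy.built)        # 1
    everything = list(lazy)  # eager: all 1000 batches are built and held
    print(lazy.built, len(everything))  # 1001 1000
```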