def train():
# Model
model = Model()
# Loss, Optimizer
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step')
loss_fn = model.loss()
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)
# Summaries
summary_op = summaries(model, loss_fn)
with tf.Session(config=TrainConfig.session_conf) as sess:
# Initialized, Load state
sess.run(tf.global_variables_initializer())
model.load_state(sess, TrainConfig.CKPT_PATH)
writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)
# Input source
data = Data(TrainConfig.DATA_PATH)
loss = Diff()
for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP):
mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step)
mixed_spec = to_spectrogram(mixed_wav)
mixed_mag = get_magnitude(mixed_spec)
src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)
src1_batch, _ = model.spec_to_batch(src1_mag)
src2_batch, _ = model.spec_to_batch(src2_mag)
mixed_batch, _ = model.spec_to_batch(mixed_mag)
# Initializae our callback.
#early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5)
l, _, summary = sess.run([loss_fn, optimizer, summary_op],
feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch,
model.y_src2: src2_batch})
loss.update(l)
print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value))
writer.add_summary(summary, global_step=step)
# Save state
if step % TrainConfig.CKPT_STEP == 0:
tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step)
writer.close()
我有一个神经网络代码,可以把音乐和.wav文件中的声音分开。 如何引入一种提前停车算法来停止列车区段?我看到一个关于ValidationMonitor的项目。有人能帮我吗?
以下是我对u可以适应的早期停止的实现:
早期停止可以应用于训练过程的某些阶段,例如在每个阶段的末尾。具体地说,在我的例子中,我在每个阶段监视测试(验证)丢失,并且在测试丢失在
20
个阶段(self.require_improvement= 20
)之后没有改善,训练被中断。您可以将max epochs设置为10000或20000或任何您想要的值(
self.max_epochs = 10000
)。以下是我的训练功能,我使用提前停车:
定义序列(自):
我们可以在这里恢复重要代码:
希望它能帮助某人:)。
因为TensorFlow版本
r1.10
中的估计器API可以使用早期停止挂钩(参见github)。例如
tf.contrib.estimator.stop_if_no_decrease_hook
(请参见docs)ValidationMonitor标记为已弃用。不建议这样做。但你还是可以用的。 下面是如何创建一个示例:
你可以自己实现,这里是我的实现:
相关问题 更多 >
编程相关推荐