我使用contrib seq2seq API创建了Dataset API管道和seq2seq编码器-解码器模型
我正在创建两个不同的解码器(共享相同的权重):
但是,我不能使用我的模型,因为我要调用解码器函数两次:
调用函数两次,我复制了一些变量
下面是我的解码器函数,它创建训练和推理解码器:
def decoder(target, hidden_state, encoder_outputs):
with tf.name_scope("decoder"):
# ... embedding the targets
decoder_inputs = embeddings(target)
decoder_gru_cell = tf.nn.rnn_cell.GRUCell(dec_units, name="gru_cell")
# Here I create the training decoder part
with tf.variable_scope("decoder"):
training_helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs, max_length)
training_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_gru_cell, training_helper, hidden_state)
training_decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder, max_length)
# And here I create the inference decoder part
with tf.variable_scope("decoder", reuse=True):
inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(...)
inference_decoder = tf.contrib.seq2seq.BasicDecoder(decoder_gru_cell, inference_helper, hidden_state)
inference_decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(inference_decoder, max_length)
return training_decoder_outputs, inference_decoder_outputs
在这里我创建了我的模型:
def seq2_seq2_model(values, labels):
encoder_outputs, hidden_state = encoder(values)
training_decoder, inference_decoder = decoder(labels, hidden_state, encoder_outputs)
return training_decoder, inference_decoder
以下是我的数据集,我将其分为训练部分和测试部分(大小n\u测试):
values_dataset = tf.data.Dataset.from_tensor_slices(values)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
X_Y_dataset = tf.data.Dataset.zip((features_dataset, caption_dataset))
X_Y_test = X_Y_dataset.take(n_test).batch(n_test)
X_Y_train = X_Y_dataset.skip(n_test).batch(batch_size)
test_iterator = X_Y_test.make_initializable_iterator()
x_y_test_next = test_iterator.get_next()
train_iterator = X_Y_train.make_initializable_iterator()
x_y_train_next = train_iterator.get_next()
最后,我通过调用seq2\ seq2\ model来构建我的模型:
training_decoder_outputs, _ = seq2_seq2_model(*x_y_train_next)
_, inference_decoder_outputs = seq2_seq2_model(*x_y_test_next)
错误来了,因为我创建了两次可变解码器\u gru\u单元格
ValueError: Variable decoder/decoder/attention_wrapper/gru_cell/gates/kernel already exists, disallowed.
我可以为复制的变量创建一个全局变量,但这似乎是解决问题的一种肮脏的方法。另外,我展示的代码是我的一个简化版本:我必须创建几个全局变量
我终于找到了。 关键是使用一个可重新初始化的迭代器,以便切换数据集作为输入源
然后我们可以在一个步骤中创建解码器:
最后,我们使用train_init_op和test_init_op来指定是要使用train数据集还是test_数据集:
相关问题 更多 >
编程相关推荐