Tensorflow LSTM编码器测试代码尺寸不匹配

2024-04-28 08:47:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用tensorflow编写一个lstm编码器。然后我写了一个测试代码来看看我的代码是如何工作的。这是我的密码:

import tensorflow as tf

class Encoder(object):
    def __init__(self, state_size,vocab_dim, FLAGS):
        self.state_size = state_size
        self.FLAGS = FLAGS
        self.vocab_dim = vocab_dim
    # Return hidden representation HQ, HP of question and paragraph respectively
    def LSTMpreprocessing(self,paragraph,question, paragraph_length,question_length):
        #Encode Question
        with tf.variable_scope("Q_encode"):
            cell = tf.nn.rnn_cell.BasicLSTMCell(self.state_size)
            HQ, _ = tf.nn.dynamic_rnn(cell,question,sequence_length = question_length, dtype = tf.float32)

        #Encode Paragraph
        with tf.variable_scope("P_encode"):
            cell = tf.nn.rnn_cell.BasicLSTMCell(self.state_size)
            HP, _ = tf.nn.dynamic_rnn(cell,paragraph,sequence_length = paragraph_length, dtype = tf.float32)
        return HQ,HP

目前,我正在尝试检查什么是后处理返回。为此,我编写了以下测试代码:

def main(_):
    paragraph_placeholder = tf.placeholder(tf.int32, (None, 4), name="paragraph_placeholder")
    question_placeholder = tf.placeholder(tf.int32, (None, 3), name="question_placeholder")
    paragraph_length = tf.placeholder(tf.int32, (None), name="paragraph_length")
    question_length = tf.placeholder(tf.int32, (None), name="question_length")
    encoder = Encoder(4,3,None)

    paragraph = [[0,1],[1,0],[0,2],[5,3]]
    question = [[3,3],[5,5],[1,1]]
    func = encoder.LSTMpreprocessing(paragraph_placeholder,question_placeholder,paragraph_length,question_length)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        HQ,HP = sess.run(func, feed_dict = {paragraph_placeholder :paragraph, question_placeholder : question,paragraph_length : 4, question_length : 3}) 
    print(HQ.get_shape().as_list())
    print(HP.get_shape().as_list())

运行上述测试代码时,出现以下错误:

    ValueError: Dimension must be 2 but is 3 for 
'Q_encode/transpose' (op: 'Transpose') with input shapes: [?,3], [3].

作为tensorflow的新手,我完全不知道自己做错了什么。有人能帮我指出我犯的错误吗?你知道吗


Tags: selfnonesizetfaswithcellnn