Infinite loop when reading a TFRecord file

Posted 2024-04-23 11:57:18


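I'm trying to read a TFRecord file, but whenever I try to evaluate the tensors I just built from the file's data, my terminal hangs. I'm working on RNNs (recurrent neural networks), so I'm trying to use sequence data. I copied most of the code from here, but added a few things myself to try to include a TFRecordReader.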

import tensorflow as tf
import tempfile
import os

sequences = [[1, 2, 3], [4, 5, 1], [1, 2]]
label_sequences = [[0, 1, 0], [1, 0, 0], [1, 1]]

def make_example(sequence, labels):
    # The object we return
    ex = tf.train.SequenceExample()
    # A non-sequential feature of our example
    sequence_length = len(sequence)
    ex.context.feature["length"].int64_list.value.append(sequence_length)
    # Feature lists for the two sequential features of our example
    fl_tokens = ex.feature_lists.feature_list["tokens"]
    fl_labels = ex.feature_lists.feature_list["labels"]
    for token, label in zip(sequence, labels):
        fl_tokens.feature.add().int64_list.value.append(token)
        fl_labels.feature.add().int64_list.value.append(label)
    return ex

def parse_example(filename_queue):

    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)
    print(example)

    #example = filename_queue

    context_features = {
        "length": tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        "labels": tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }

    # Parse the example
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    return context_parsed, sequence_parsed

if __name__ == '__main__':
    generated_file = ""
    #################################
    # Write all examples into a TFRecords file
    #################################
    with tempfile.NamedTemporaryFile(dir=".", delete=False) as fp:
        writer = tf.python_io.TFRecordWriter(fp.name)
        generated_file = fp.name
        for sequence, label_sequence in zip(sequences, label_sequences):
            ex = make_example(sequence, label_sequence)
            writer.write(ex.SerializeToString())
        writer.close()

    #################################
    # Read contents of TFRecord file
    #################################


    filename = os.path.join(os.getcwd(), generated_file)
    print(filename)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        filename_queue = tf.train.string_input_producer([filename])
        context_parsed, sequence_parsed = parse_example(filename_queue)
        print(context_parsed)
        print(sequence_parsed)

        # THIS IS WHERE I NEED HELP
        # terminal freezes if the following lines are uncommented

        #print(context_parsed["length"].eval()) 
        #print(sequence_parsed["tokens"].eval()) 
        #print(sequence_parsed["labels"].eval()) 
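What is most likely going on in the commented-out lines: `tf.train.string_input_producer` only adds a queue to the graph together with a `QueueRunner` that fills it, and that runner does nothing until it is started with `tf.train.start_queue_runners`. `reader.read(filename_queue)` has to dequeue a filename before it can read any record, so `.eval()` waits forever on a queue that nothing is filling. The sketch below is a stdlib-only analogy, not TensorFlow code: `queue.Queue` stands in for the filename queue, the hypothetical `producer` function stands in for the queue runner, the hypothetical `read_one` stands in for `reader.read`, and `"data.tfrecord"` is a made-up filename. A timeout replaces the indefinite block so the stall is visible.

```python
import queue
import threading


def read_one(filename_queue, timeout):
    """Dequeue one item, or return None if nothing arrives in time.

    Hypothetical stand-in for reader.read(filename_queue): it can only
    make progress once something has been enqueued.
    """
    try:
        return filename_queue.get(timeout=timeout)
    except queue.Empty:
        return None


def producer(filename_queue, filenames):
    # Hypothetical stand-in for the QueueRunner that string_input_producer registers.
    for name in filenames:
        filename_queue.put(name)


# Case 1: the producer is never started, as in the original code.
q1 = queue.Queue()
stalled = read_one(q1, timeout=0.5)  # without a timeout this would block forever

# Case 2: start the producer first, analogous to tf.train.start_queue_runners.
q2 = queue.Queue()
t = threading.Thread(target=producer, args=(q2, ["data.tfrecord"]))
t.start()
got = read_one(q2, timeout=5.0)
t.join()

print(stalled, got)
```

Running it prints `None data.tfrecord`: the first read gives up because no producer ever ran, and the second succeeds because the producer was started before reading.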

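I'm still fairly new to TensorFlow, so I'd appreciate an explanation of what I did wrong (not just a code fix), so that I don't make similar mistakes in the future. Thank you!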

Edit: OK, I changed the code above to the following:

^{pr2}$

But this gives me the following error:

ERROR:tensorflow:Exception in QueueRunner: Attempted to use a closed Session.

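I don't understand what is happening here. According to the directions on tensorflow, Yaroslav's comment is correct, but I can't tell what I'm doing wrong. I've tried arranging the code in several different ways, but I still can't see what I'm missing.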
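A plausible explanation, assuming the edited code started the queue runners inside the `with tf.Session()` block: the error is the mirror image of the hang. Once started, the runners keep working in background threads, and if the `with` block ends without stopping them, their next operation uses a session that is already closed. The sketch below reproduces that ordering with the standard library only; `FakeSession` and `runner_loop` are hypothetical stand-ins for `tf.Session` and a queue-runner thread, and the error string is chosen to match the log message above.

```python
import threading
import time


class FakeSession:
    """Hypothetical stand-in for tf.Session; not a TensorFlow class."""

    def __init__(self):
        self.closed = False

    def run(self):
        if self.closed:
            raise RuntimeError("Attempted to use a closed Session.")
        time.sleep(0.01)  # pretend to enqueue one element


errors = []


def runner_loop(sess, stop_event):
    # Like a queue-runner thread: keep enqueueing until told to stop.
    try:
        while not stop_event.is_set():
            sess.run()
    except RuntimeError as e:
        errors.append(str(e))


sess = FakeSession()
stop = threading.Event()  # never set below: nobody asks the thread to stop
t = threading.Thread(target=runner_loop, args=(sess, stop))
t.start()
time.sleep(0.05)
sess.closed = True  # the `with` block ends while the thread is still running
t.join(timeout=2.0)
print(errors)
```

After the simulated session is closed, the runner's next `run()` call fails and `errors` ends up holding exactly that one message, because the thread was never asked to stop first.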

Edit 2: OK, I figured it out. Apparently I needed to create a coordinator and pass it to the queue runners for this to work. Here is my final code:

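    filename_queue = tf.train.string_input_producer([filename])
    context_parsed, sequence_parsed = parse_example(filename_queue)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        # Start the threads that fill filename_queue; without them reader.read() blocks forever.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(3):
            v1, v2, v3 = sess.run([context_parsed["length"], sequence_parsed["tokens"], sequence_parsed["labels"]])
            print(v1, v2, v3)
        # Stop and join the runner threads before the session is closed.
        coord.request_stop()
        coord.join(threads)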
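The working version reads while the runner threads are alive and the session is still open. The general shape is: start the runners, do the work, then ask them to stop and wait for them before the session goes away. In TF 1.x that last step is `coord.request_stop()` followed by `coord.join(threads)`, where `threads` is the list returned by `tf.train.start_queue_runners`. Below is a stdlib-only sketch of that shape. `MiniCoordinator` is a hypothetical, much smaller analogue of `tf.train.Coordinator` (only the method names `should_stop`, `request_stop` and `join` are borrowed), and `RECORDS` simply mirrors the three examples written by `make_example`.

```python
import queue
import threading


class MiniCoordinator:
    """Hypothetical, stripped-down analogue of tf.train.Coordinator."""

    def __init__(self):
        self._stop = threading.Event()

    def should_stop(self):
        return self._stop.is_set()

    def request_stop(self):
        self._stop.set()

    def join(self, threads):
        for th in threads:
            th.join()


# (length, tokens, labels) for the three sequences in the question.
RECORDS = [(3, [1, 2, 3], [0, 1, 0]), (3, [4, 5, 1], [1, 0, 0]), (2, [1, 2], [1, 1])]


def runner(coord, q):
    # Cycle over the records until the coordinator says stop
    # (string_input_producer also repeats epochs by default).
    i = 0
    while not coord.should_stop():
        try:
            q.put(RECORDS[i % len(RECORDS)], timeout=0.05)
            i += 1
        except queue.Full:
            pass  # queue is full; loop back and re-check should_stop()


coord = MiniCoordinator()
q = queue.Queue(maxsize=2)
threads = [threading.Thread(target=runner, args=(coord, q))]
for th in threads:
    th.start()
try:
    results = [q.get(timeout=5.0) for _ in range(3)]
finally:
    coord.request_stop()  # ask the runner to finish...
    coord.join(threads)   # ...and wait for it before the "session" closes
print(results)
```

Putting the stop and join in `finally` shuts the runner down even if a read fails, which is the same pattern the TF 1.x queue-based input examples use around their coordinator.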
