我正在尝试读取一个TFRecord文件,但是每当我试图计算刚从文件数据生成的张量时,我的终端就会崩溃。我正在研究RNNs(递归神经网络),所以我尝试使用序列数据来完成我的工作。我已经从here复制了大部分代码,但是我自己添加了一些内容来尝试包含一个TFRecordReader。在
import tensorflow as tf
import tempfile
import os
sequences = [[1, 2, 3], [4, 5, 1], [1, 2]]
label_sequences = [[0, 1, 0], [1, 0, 0], [1, 1]]
def make_example(sequence, labels):
# The object we return
ex = tf.train.SequenceExample()
# A non-sequential feature of our example
sequence_length = len(sequence)
ex.context.feature["length"].int64_list.value.append(sequence_length)
# Feature lists for the two sequential features of our example
fl_tokens = ex.feature_lists.feature_list["tokens"]
fl_labels = ex.feature_lists.feature_list["labels"]
for token, label in zip(sequence, labels):
fl_tokens.feature.add().int64_list.value.append(token)
fl_labels.feature.add().int64_list.value.append(label)
return ex
def parse_example(filename_queue):
reader = tf.TFRecordReader()
_, example = reader.read(filename_queue)
print(example)
#example = filename_queue
context_features = {
"length": tf.FixedLenFeature([], dtype=tf.int64)
}
sequence_features = {
"tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
"labels": tf.FixedLenSequenceFeature([], dtype=tf.int64)
}
# Parse the example
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
serialized=example,
context_features=context_features,
sequence_features=sequence_features
)
return context_parsed, sequence_parsed
if __name__ == '__main__':
generated_file = ""
#################################
# Write all examples into a TFRecords file
#################################
with tempfile.NamedTemporaryFile(dir=".", delete=False) as fp:
writer = tf.python_io.TFRecordWriter(fp.name)
generated_file = fp.name
for sequence, label_sequence in zip(sequences, label_sequences):
ex = make_example(sequence, label_sequence)
writer.write(ex.SerializeToString())
writer.close()
#################################
# Read contents of TFRecord file
#################################
filename = os.path.join(os.getcwd(), generated_file)
print(filename)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
filename_queue = tf.train.string_input_producer([filename])
context_parsed, sequence_parsed = parse_example(filename_queue)
print(context_parsed)
print(sequence_parsed)
# THIS IS WHERE I NEED HELP
# terminal freezes if the following lines are uncommented
#print(context_parsed["length"].eval())
#print(sequence_parsed["tokens"].eval())
#print(sequence_parsed["labels"].eval())
我对tensorflow还是比较熟悉的,所以我希望能解释一下我做错了什么(不仅仅是代码修复),这样我以后就不会犯类似的错误了。谢谢您!在
—
编辑:好的,我把上面的内容改了如下:
^{pr2}$但这给了我以下错误:
错误:tensorflow:QueueRunner中出现异常:试图使用关闭的会话。
我不知道这是怎么回事。根据directions on tensorflow,雅罗斯拉夫的评论是正确的,但我不确定我做错了什么。我试过几种不同格式的代码,但我不确定我遗漏了什么。在
—
编辑2:好吧,我想好了。显然,我需要生成一个协调器,并将其传递给队列运行者以使其工作。以下是我的最终代码:
with tf.Session() as sess:
coord = tf.train.Coordinator()
tf.train.start_queue_runners(coord=coord)
context_parsed, sequence_parsed = parse_example(filename_queue)
for i in range(3):
v1,v2,v3 = sess.run([context_parsed["length"], sequence_parsed["tokens"], sequence_parsed["labels"]])
print(v1,v2,v3)
目前没有回答
相关问题 更多 >
编程相关推荐