Tensorflow 冻结随机数据集
我在训练模型的时候,使用了一个经过打乱的数据集,这个数据集是这样创建的:
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
但是在训练之后,我想用一些简单的测试函数来检查预测结果。
def detect_wrong_ts(model,ds,label_encoder):
# Take first elements batch
ds = ds.take(1)
# Do data preprocessing
# it converts strings from initial dataset to indexes "word1 word2" -> [ 1, 2 ]
mapper = lambda in1,in2,out : (model.preprocessor(in1),out)
train = ds.map(mapper)
# Do predictions on postprocessed dataset
x = model.predict(train)
# Now I try to print results with the information based
# on the
# 1 ) initial dataset
# 2 ) postprocessed dataset
# 3 ) prediction
for batch,outp in zip(ds,train):
inp = batch[0]
outp_code = outp[0]
for i,inp in enumerate(inp):
inp_str = inp.numpy().decode("utf-8")
inp_codes = model.preprocessor([inp_str])[0].numpy()
postprocess_codes = outp_code[i].numpy()
print(
f"#{i} {inp_str=} {inp_codes=} {postprocess_codes=}",
)
可我发现这两个数据集的结果是打乱的,彼此不匹配。原因很明显:每次启动新的迭代器时,数据就会重新打乱。
我可以关闭完全打乱的功能,但这样做对于整个流程来说比较复杂。
我在考虑有没有办法可以“冻结”这个数据集,让它在每次迭代时都能产生相同的结果。
这样解决这个问题可行吗?
1 个回答
0
解决方案很简单:
使用 cache()
函数
所以
ds = ds.take(1).cache()
这将使所有新的迭代器保持稳定