Tensorflow 冻结随机数据集

0 投票
1 回答
19 浏览
提问于 2025-04-14 17:14

我在训练模型的时候,使用了一个经过打乱的数据集,这个数据集是这样创建的:

train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

但是在训练之后,我想用一些简单的测试函数来检查预测结果。

def detect_wrong_ts(model,ds,label_encoder):
    # Take first elements batch 
    ds = ds.take(1)
    # Do data preprocessing 
    # it converts strings from initial dataset to indexes "word1 word2" -> [ 1, 2 ]
    mapper = lambda in1,in2,out : (model.preprocessor(in1),out)
    train = ds.map(mapper)
    
    # Do predictions on postprocessed dataset
    x = model.predict(train)

    # Now I try to print results with the information based 
    # on the 
    # 1 ) initial dataset 
    # 2 ) postprocessed dataset 
    # 3 ) prediction
    for batch,outp in zip(ds,train):
        inp = batch[0]
        outp_code = outp[0]
        for i,inp in enumerate(inp):
            inp_str = inp.numpy().decode("utf-8")
            inp_codes = model.preprocessor([inp_str])[0].numpy()
            postprocess_codes = outp_code[i].numpy()
            print(
                  f"#{i} {inp_str=} {inp_codes=} {postprocess_codes=}",
            )

可我发现这两个数据集的结果是打乱的,彼此不匹配。原因很明显:每次启动新的迭代器时,数据就会重新打乱。

我可以关闭完全打乱的功能,但这样做对于整个流程来说比较复杂。

我在考虑有没有办法可以“冻结”这个数据集,让它在每次迭代时都能产生相同的结果。

这样解决这个问题可行吗?

1 个回答

0

解决方案很简单:

使用 cache() 函数

所以

ds = ds.take(1).cache()

这将使所有新的迭代器保持稳定

撰写回答