如果序列长度大于1,如何从LSTM模型中采样文本?

2024-05-16 01:48:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图使用model.predict(...)从我的LSTM模型中抽取文本样本。我的模型的序列长度是3。在培训的上下文中,这不是问题,我将第一个sequence_length标记输入到我的模型中。你知道吗

但是,在采样/生成文本的上下文中,我还需要sequence_length标记来预测下一个标记。你知道吗

我已经这样做了,在生成的标记的开始,总是有一个标记指示一个句子的开始。如果我的模型是vanilla RNN而不是LSTM,这就解决了这个问题,因为vanilla RNN只基于前一个标记预测当前标记,而不是基于前一个标记s)。在vanillarnn中,通过向所有输入附加一个开始标记,它可以作为第一次迭代的前一个标记。你知道吗

我的研究/尝试:

  • vanilla RNNhere的实现能够对文本进行采样,但是由于上面提到的限制,无法将其采样方法应用于LSTM。

  • 我尝试用小于sequence_length的输入数组调用model.predict(...)。如我所料,model.predict(...)要求输入数组长度等于sequence_length

如果没有足够的先前标记输入到模型中,我如何允许我的模型对给定的先前标记少于采样开始时所需的标记的数据进行采样?


Tags: 标记模型文本model序列数组lengthpredict