如何处理动态形状?

2024-04-25 09:27:03 发布

您现在位置:Python中文网/ 问答频道 /正文

我是一个使用Tensorflow的初学者,我开始了手写识别任务。目前,我正在测试GridLSTMCell(https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/GridLSTMCell),当我试图用这种类型的递归单元初始化动态RNN时,我遇到了一些问题。在

cell = rnn.GridLSTMCell(num_units=num_units, feature_size=feature_size, frequency_skip=frequency_skip, num_frequency_blocks=num_frequency_blocks, forget_bias=1, state_is_tuple=True)
outputs, _ = tf.nn.dynamic_rnn(cell, inputs, seq_len, dtype=tf.float32)

显然,这个问题是由于GridLSTMCell中存在的一个限制造成的,它只能处理静态形状的输入,但是正如您在下面看到的,在我的例子中,批处理大小可能会有所不同,因为我有一个batch_size=1的训练集和一个有100个示例的固定测试集。在

^{pr2}$

在这种情况下,我想知道是否有人有解决办法,也许有一些技巧可以定义固定批处理大小=1的输入,并且能够将模型应用到有100个示例的测试集。类似问题(Tensorflow Grid LSTM RNN TypeError)中的一些答案与TF中的另一个gridlstm实现有关。在

很抱歉我的英语很差,谢谢你的支持。在

提前谢谢!在


Tags: 示例sizetftensorflowcellnumfeatureunits