Tensorflow:替换tf.nn.rnn_cell._linear(输入,大小,0,范围)

2024-06-16 19:10:58 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试图让SequenceGAN(https://github.com/LantaoYu/SeqGAN)从https://arxiv.org/pdf/1609.05473.pdf运行。
在修复了诸如用stack替换pack之类的明显错误之后,它仍然不运行,因为公路网部分需要tf.nn.rnn_cell._linear函数:

# highway layer that borrowed from https://github.com/carpedm20/lstm-char-cnn-tensorflow
def highway(input_, size, layer_size=1, bias=-2, f=tf.nn.relu):
    """Highway Network (cf. http://arxiv.org/abs/1505.00387).

    t = sigmoid(Wy + b)
    z = t * g(Wy + b) + (1 - t) * y
    where g is nonlinearity, t is transform gate, and (1 - t) is carry gate.
    """
    output = input_
    for idx in range(layer_size):
        output = f(tf.nn.rnn_cell._linear(output, size, 0, scope='output_lin_%d' % idx)) #tf.contrib.layers.linear instad doesn't work either.
        transform_gate = tf.sigmoid(tf.nn.rnn_cell._linear(input_, size, 0, scope='transform_lin_%d' % idx) + bias)
        carry_gate = 1. - transform_gate

        output = transform_gate * output + carry_gate * input_

    return output

在Tensorflow 1.0或0.12中,tf.nn.rnn_cell._linear函数似乎不再存在,我不知道用什么来替换它。我找不到这个的任何新实现,也找不到关于tensorflow的github或(不幸的是非常稀疏的)文档的任何信息。

有人知道这个功能的新挂件吗? 提前多谢!


Tags: httpsgithublayerinputoutputsizeistf
3条回答

若何·若特西的回答几乎是正确的: 然而,linear定义并不位于tf.contrib.rnn.basicRNNCell,而是分别位于tf.contrib.rnn.python.ops.rnn_celltf.contrib.rnn.python.ops.core_rnn_cell_impl

你可以找到他们的源代码herehere

在1.0版中,东西到处移动。我也有过类似的搜索更新tf.nn.rnn_cell.LSTMCelltf.contrib.rnn.BasicLSTMCell

对于你的情况,tf.nn.rnn_cell._linear现在存在于tf.contrib.rnn.python.ops.core_rnn_cell_impl以及BasicRNNCell的定义中。检查BasicRNNCell docssource code,我们在L113-L118中看到了使用{u线性。

  def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    with _checked_scope(self, scope or "basic_rnn_cell", reuse=self._reuse):
      output = self._activation(
          _linear([inputs, state], self._num_units, True))
    return output, output

线性方法在line 854定义为:
Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

祝你好运!

我在使用SkFlow的TensorFlowDNRegressor时遇到了这个错误。 第一次看到若霍若特的回答,我有点困惑。 但第二天我明白了他的意思。

以下是我的工作:

from tensorflow.python.ops import rnn_cell_impl

rnn_cell_impl._linear替换tf.nn.rnn_cell._linear

相关问题 更多 >