Tensorflow,在RNN中保存状态的最佳方法?

2024-05-12 21:29:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我现在有下面的代码用于tensorflow中一系列链接在一起的rnn。我没有使用MultiRNN,因为我稍后要对每个层的输出做一些事情。

 for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        state = [tf.zeros((BATCH_SIZE, sz)) for sz in rnn_func.state_size]
        time_outputs = [None] * TIME_STEPS

        for t in range(TIME_STEPS):
            rnn_input = getTimeStep(rnn_outputs[r - 1], t)
            time_outputs[t], state = rnn_func(rnn_input, state)
            time_outputs[t] = tf.reshape(time_outputs[t], (-1, 1, RNN_SIZE))
            scope.reuse_variables()
        rnn_outputs[r] = tf.concat(1, time_outputs)

目前我有固定的时间步数。不过,我希望将其更改为只有一个timestep,但记住批处理之间的状态。因此,我需要为每个层创建一个状态变量,并为其分配每个层的最终状态。像这样的东西。

for r in range(RNNS):
    with tf.variable_scope('recurent_%d' % r) as scope:
        saved_state = tf.get_variable('saved_state', ...)
        rnn_outputs[r], state = rnn_func(rnn_outputs[r - 1], saved_state)
        saved_state = tf.assign(saved_state, state)

然后,对于每个层,我需要评估sess.run函数中保存的状态,并调用我的训练函数。我需要对每个rnn层都这样做。这看起来有点麻烦。我需要跟踪每个保存的状态并在运行时对其进行评估。另外,run需要将状态从我的GPU复制到主机内存,这将是低效和不必要的。有更好的办法吗?


Tags: infortime状态tfwithrangeoutputs
3条回答

我现在使用tf.control_依赖项保存RNN状态。这是一个例子。

 saved_states = [tf.get_variable('saved_state_%d' % i, shape = (BATCH_SIZE, sz), trainable = False, initializer = tf.constant_initializer()) for i, sz in enumerate(rnn.state_size)]
        W = tf.get_variable('W', shape = (2 * RNN_SIZE, RNN_SIZE), initializer = tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
        b = tf.get_variable('b', shape = (RNN_SIZE,), initializer = tf.constant_initializer())

        rnn_output, states = rnn(last_output, saved_states)
        with tf.control_dependencies([tf.assign(a, b) for a, b in zip(saved_states, states)]):
            dense_input = tf.concat(1, (last_output, rnn_output))

        dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
        last_output = dense_output + last_output

我只是确保我的图形的一部分依赖于保存状态。

下面是通过定义状态变量来更新LSTM初始状态的代码。它还支持多个层。

我们定义了两个函数-一个用于获取初始状态为零的状态变量,另一个用于返回操作,我们可以传递给session.run,以便用LSTM的最后一个隐藏状态更新状态变量。

def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    return tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a tuple in order to combine all update_ops into a single operation.
    # The tuple's actual value should not be used.
    return tf.tuple(update_ops)

我们可以使用它在每个批处理之后更新LSTM的状态。请注意,我使用tf.nn.dynamic_rnn展开:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell_layer = tf.contrib.rnn.GRUCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers)

# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

this answer的主要区别在于state_is_tuple=True使LSTM的状态成为包含两个变量(单元状态和隐藏状态)的LSTMStateTuple,而不仅仅是一个变量。然后使用多个层使LSTM的状态成为LSTMStateTuples的元组-每层一个。

重置为零

使用经过训练的模型进行预测/解码时,可能需要将状态重置为零。然后,您可以使用此功能:

def get_state_reset_op(state_variables, cell, batch_size):
    # Return an operation to set each variable in a list of LSTMStateTuples to zero
    zero_states = cell.zero_state(batch_size, tf.float32)
    return get_state_update_op(state_variables, zero_states)

例如,如上所述:

reset_state_op = get_state_reset_op(state, cell, max_batch_size)
# Reset the state to zero before feeding input
sess.run([reset_state_op])
sess.run([outputs, update_op], {data: ...})

相关问题 更多 >