使用带有RNNs的Tensorflow和批次标准化

2024-06-09 01:31:50 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直在追踪Tensorflow的一段断层。可以使用以下代码片段重现该问题:

import tensorflow as tf                                                                                                                                                                                                                       

with tf.device('/cpu:0'):
    xin = tf.placeholder(tf.float32, [None, 1, 1], name='input')
    rnn_cell = tf.contrib.rnn.LSTMCell(1)
    out, _ = tf.nn.dynamic_rnn(rnn_cell, xin, dtype=tf.float32)
    out = tf.layers.batch_normalization(out, training=True)
    out = tf.identity(out, name='output')

    optimiser = tf.train.AdamOptimizer(.0001)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        out = optimiser.minimize(out, global_step=tf.Variable(0, dtype=tf.float32), name='train_op')

config = tf.ConfigProto(allow_soft_placement = False)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

sample_in = [[[0]]]
sess.run(out, feed_dict={xin: sample_in})

我设法找到了这个问题,我得到了一个pull-request for it on github。如果我的一个修补程序运行的是以下错误:

^{pr2}$

这似乎表明我的示例代码存在拓扑问题。每当我将任何类型的RNN、批处理规范化和the required additional control dependency结合起来时,问题似乎就发生了

Batch normalisation note

我通过依赖^{}并将updates_collections参数设置为None来内联更新操作,从而设法减轻了这个问题。在

以下是更新后的代码示例以供参考:

import tensorflow as tf                                                                                                                                                                                                                       

with tf.device('/cpu:0'):
    xin = tf.placeholder(tf.float32, [None, 1, 1], name='input')
    rnn_cell = tf.contrib.rnn.LSTMCell(1)
    out, _ = tf.nn.dynamic_rnn(rnn_cell, xin, dtype=tf.float32)
    out = tf.contrib.layers.batch_norm(out, is_training=True, updates_collections=None)
    out = tf.identity(out, name='output')

    optimiser = tf.train.AdamOptimizer(.0001)
    out = optimiser.minimize(out, global_step=tf.Variable(0, dtype=tf.float32), name='train_op')

config = tf.ConfigProto(allow_soft_placement = False)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

sample_in = [[[0]]]
sess.run(out, feed_dict={xin: sample_in})

根据the documentation的说法,这可能会对性能产生负面影响,而且我不清楚我首先做错了什么。我的代码看起来正确吗?在

另外请注意,只有在Tensorflow使用XLA JIT支持构建时才会出现此问题,这可能是Tensorflow中的一个bug。在

编辑:我还提交了一个问题on Github


Tags: 代码namenoneconfigtfcelltrainout