我一直在追踪Tensorflow的一段断层。可以使用以下代码片段重现该问题:
import tensorflow as tf
with tf.device('/cpu:0'):
xin = tf.placeholder(tf.float32, [None, 1, 1], name='input')
rnn_cell = tf.contrib.rnn.LSTMCell(1)
out, _ = tf.nn.dynamic_rnn(rnn_cell, xin, dtype=tf.float32)
out = tf.layers.batch_normalization(out, training=True)
out = tf.identity(out, name='output')
optimiser = tf.train.AdamOptimizer(.0001)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
out = optimiser.minimize(out, global_step=tf.Variable(0, dtype=tf.float32), name='train_op')
config = tf.ConfigProto(allow_soft_placement = False)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
sample_in = [[[0]]]
sess.run(out, feed_dict={xin: sample_in})
我设法找到了这个问题,我得到了一个pull-request for it on github。如果我的一个修补程序运行的是以下错误:
^{pr2}$这似乎表明我的示例代码存在拓扑问题。每当我将任何类型的RNN、批处理规范化和the required additional control dependency结合起来时,问题似乎就发生了
我通过依赖^{updates_collections
参数设置为None
来内联更新操作,从而设法减轻了这个问题。在
以下是更新后的代码示例以供参考:
import tensorflow as tf
with tf.device('/cpu:0'):
xin = tf.placeholder(tf.float32, [None, 1, 1], name='input')
rnn_cell = tf.contrib.rnn.LSTMCell(1)
out, _ = tf.nn.dynamic_rnn(rnn_cell, xin, dtype=tf.float32)
out = tf.contrib.layers.batch_norm(out, is_training=True, updates_collections=None)
out = tf.identity(out, name='output')
optimiser = tf.train.AdamOptimizer(.0001)
out = optimiser.minimize(out, global_step=tf.Variable(0, dtype=tf.float32), name='train_op')
config = tf.ConfigProto(allow_soft_placement = False)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
sample_in = [[[0]]]
sess.run(out, feed_dict={xin: sample_in})
根据the documentation的说法,这可能会对性能产生负面影响,而且我不清楚我首先做错了什么。我的代码看起来正确吗?在
另外请注意,只有在Tensorflow使用XLA JIT支持构建时才会出现此问题,这可能是Tensorflow中的一个bug。在
编辑:我还提交了一个问题on Github
目前没有回答
相关问题 更多 >
编程相关推荐