我试图在tensor-flow
中实现一个批处理规范化层。我可以使用tf.moments
来获得平均值和方差。在
对于测试时间,我想设置一个指数移动平均值来跟踪均值和方差。我试着这样做:
def batch_normalized_linear_layer(state_below, scope_name, n_inputs, n_outputs, stddev, wd, eps=.0001):
with tf.variable_scope(scope_name) as scope:
weight = _variable_with_weight_decay(
"weights", shape=[n_inputs, n_outputs],
stddev=stddev, wd=wd
)
act = tf.matmul(state_below, weight)
# get moments
act_mean, act_variance = tf.nn.moments(act, [0])
# get mean and variance variables
mean = _variable_on_cpu('bn_mean', [n_outputs], tf.constant_initializer(0.0))
variance = _variable_on_cpu('bn_variance', [n_outputs], tf.constant_initializer(1.0))
# assign the moments
assign_mean = mean.assign(act_mean)
assign_variance = variance.assign(act_variance)
act_bn = tf.mul((act - mean), tf.rsqrt(variance + eps), name=scope.name+"_bn")
beta = _variable_on_cpu("beta", [n_outputs], tf.constant_initializer(0.0))
gamma = _variable_on_cpu("gamma", [n_outputs], tf.constant_initializer(1.0))
bn = tf.add(tf.mul(act_bn, gamma), beta)
output = tf.nn.relu(bn, name=scope.name)
_activation_summary(output)
return output, mean, variance
其中,cpu上的_变量定义为:
^{pr2}$我相信我已经准备好了
assign_mean = mean.assign(act_mean)
assign_variance = variance.assign(act_variance)
不正确,但我不知道怎么做。当我使用tensorboard跟踪这些均值和方差变量时,它们与它们的初始值是平的。在
Rafal的评论抓住了问题的核心:您没有运行assign节点。您可以尝试使用我在另一个答案中发布的batchnorm帮助器How could I use Batch Normalization in TensorFlow?-或者您可以按照他的建议,通过添加依赖项来强制赋值。在
一般原则是,只有当数据或控件依赖项“通过”一个节点时,才应该指望该节点正在运行。
with_dependencies
确保在使用输出操作之前,指定的依赖关系已经完成。在相关问题 更多 >
编程相关推荐