分批范数提取TensorF的运行均值和运行方差

2024-04-25 07:30:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图研究通过GCMLE(saved_model.pbassets/*&;variables/*)导出的训练张量流模型的运行均值和运行方差。这些值保存在图表中的什么位置?我可以从tf.GraphKeys.TRAINABLE_VARIABLES中获取γ/β值,但我无法找到任何tf.GraphKeys.MODEL_VARIABLES中的运行平均值和运行方差。运行平均值和运行方差是否存储在其他地方?在

我知道在测试时(即Modes.EVAL),运行平均值和运行方差被用来规范化输入的数据,然后使用gamma和beta对规范化数据进行缩放和移位。我试图在推理时查看所有我需要的变量,但是我找不到运行平均值和运行方差。它们是否只在测试时使用而不在推理时使用(Modes.PREDICT)?如果是这样的话,这就可以解释为什么在导出的模型中找不到它们,但我希望它们在那里。在

基于tf.GraphKeys我尝试过其他类似tf.GraphKeys.MOVING_AVERAGE_VARIABLES的方法,但它们也是空的。我在批处理规范化文档中也看到了这一行“注意:在培训时,移动平均值和移动方差需要更新。默认情况下,更新操作放在tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为一个依赖项添加到train\op中。”于是我尝试从保存的模型中查看tf.GraphKeys.UPDATE_OPS,它们包含一个assign opbatch_normalization/AssignMovingAvg:0,但仍然不清楚从何处获取值。在


Tags: 数据模型modeltfupdatevariables规范化ops
2条回答

移动平均值和移动方差似乎存储在tf.GraphKeys.GLOBAL_VARIABLES中,看起来MODEL_VARIABLES中没有显示任何内容是因为您需要使用tf.contrib.framework.local_variable

除了#reese0106的答案,
如果您想去掉BatchNorm的移动平均值、移动方差,
您可以用以下名称对它们进行索引。在

vars = tf.global_variables() # shows every variable being used.
vars_moving_mean_variance = []
for var in vars:
    if ("moving_mean" in var.name) or ("moving_variance" in var.name):
        vars_moving_mean_variance.append(var)

print(vars_moving_mean_variance)


p、 谢谢你的提问和回答。我也解决了我自己的问题。

相关问题 更多 >