tf.get_集合提取一个scop的变量

2024-06-09 12:40:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我在每个作用域中定义了n(例如:n=3)个作用域和x(例如:x=4)个变量。 范围包括:

model/generator_0
model/generator_1
model/generator_2

一旦我计算了损失,我想在运行时根据一个标准从一个作用域中提取并提供所有变量。因此,我选择的作用域idx的索引是一个argmin张量,转换成int32

<tf.Tensor 'model/Cast:0' shape=() dtype=int32>

我已经试过了:

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'model/generator_'+tf.cast(idx, tf.string)) 

这显然不起作用。 是否有任何方法可以使用idx将属于该特定范围的所有x变量传递给优化器。

提前谢谢!

维涅什斯里尼瓦桑


Tags: 标准model定义tftraingenerator作用域tensor