如何给张量流常数赋值?

2024-04-26 00:24:27 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在从一个.pb文件加载一个TensorFlow模型。我想改变所有层的权重。我可以提取权重,但我不能更改权重。你知道吗

我将graph_def模型转换为TensorFlow模型,但即使这样,我也无法为权重赋值,因为权重存储在“Const”类型的张量中。你知道吗

b = graph_tf.get_tensor_by_name("Variable_1:0")      
tf.assign(b, np.ones((1,1,64,64))) 

我得到以下错误:

AttributeError: 'Tensor' object has no attribute 'assign'

请提供解决此问题的方法。提前谢谢。你知道吗


Tags: 文件模型类型getbytftensorflowdef
1条回答
网友
1楼 · 发布于 2024-04-26 00:24:27

这里有一个方法你可以达到这样的目标。您希望用初始化为这些操作的值的变量替换某些常量操作,因此可以首先提取这些常量值,然后用初始化为这些值的变量创建图。请参见下面的示例。你知道吗

import tensorflow as tf

# Example graph
with tf.Graph().as_default():
    inp = tf.placeholder(tf.float32, [None, 3], name='Input')
    w = tf.constant([[1.], [2.], [3.]], tf.float32, name='W')
    out = tf.squeeze(inp @ w, 1, name='Output')
    gd = tf.get_default_graph().as_graph_def()

# Extract weight values
with tf.Graph().as_default():
    w, = tf.graph_util.import_graph_def(gd, return_elements=['W:0'])
    # Get the constant weight values
    with tf.Session() as sess:
        w_val = sess.run(w)
    # Alternatively, since it is a constant,
    # you can get the values from the operation attribute directly
    w_val = tf.make_ndarray(w.op.get_attr('value'))

# Make new graph
with tf.Graph().as_default():
    # Make variables initialized with stored values
    w = tf.Variable(w_val, name='W')
    init_op = tf.global_variables_initializer()
    # Import graph
    inp, out = tf.graph_util.import_graph_def(
        gd, input_map={'W:0': w},
        return_elements=['Input:0', 'Output:0'])
    # Change value operation
    w_upd = w[2].assign([5.])
    # Test
    with tf.Session() as sess:
        sess.run(init_op)
        print(sess.run(w))
        # [[1.]
        #  [2.]
        #  [3.]]
        sess.run(w_upd)
        print(sess.run(w))
        # [[1.]
        #  [2.]
        #  [5.]]

相关问题 更多 >