在二维tf.Variab中使用tf.scatter_update

2024-05-16 19:07:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在跟踪这个Manipulating matrix elements in tensorflow。使用tf.scatter_更新。但我的问题是: 如果我的tf.变量是2D会发生什么?比如说:

a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])

例如,如何更新每行的第一个元素并将值1赋给它?

我试过类似的

for line in range(2):
    sess.run(tf.scatter_update(a[line],[0],[1]))

但它失败了(我早就预料到了),给了我一个错误:

TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input

我怎样才能解决这种问题?

`


Tags: in元素forvaluetftensorflowlinerange
2条回答

我发现了一些东西here 我创建了一个变量名U=[[1,2,3],[4,5,6]]并希望像U[:,1]=[2,3]那样更新它,所以我做了U[:,1].assign(将U转换成U张量[2,3])

这里有一个简单的代码

x = tf.Variable([[1,2,3],[4,5,6]])
print K.eval(x)
y = [0, 0]
with tf.control_dependencies([x[:,1].assign(y)]):
    x = tf.identity(x)
print K.eval(x)

在tensorflow中,不能更新张量,但可以更新变量。

scatter_update运算符只能更新变量的第一个维度。 必须始终将引用张量传递给散点更新(a,而不是a[line])。

以下是如何更新变量的第一个元素:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
    b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])

with tf.Session(graph=g) as sess:
   sess.run(tf.initialize_all_variables())
   print sess.run(a)
   print sess.run(b)

输出:

[[0 0 0 0]
 [0 0 0 0]]
[[1 0 0 0]
 [1 0 0 0]]

但是,必须再次改变整个张量,可能会更快地分配一个全新的张量。

相关问题 更多 >