TensorFlow:在图形中使用tensor作为列表参数

2024-06-16 13:13:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我想使用tf.张量作为TensorFlow图中另一个操作的(python)list类型的参数。 换句话说,我想使用一个张量作为另一个操作的动态列表参数。这可能吗?你知道吗

可执行示例:

import tensorflow as tf
import numpy as np

graph = tf.Graph()
var1 = np.random.randn(2, 3)
var2 = np.random.randn(2, 3, 4)

with graph.as_default():
    def getRange(myTensor):
        myRank = tf.rank(myTensor)
        return tf.range(tf.constant(1), tf.squeeze(myRank))

    def getMoments(myTensor):
        myMoments = tf.nn.moments(myTensor, axes=getRange(myTensor))
        return myMoments
    var1tf = tf.Variable(var1)
    var2tf = tf.Variable(var2)
    var1moments = getMoments(var1tf)
    var2moments = getMoments(var2tf)
    rangeVar1 = getRange(var1tf)
    rangeVar2 = getRange(var2tf)
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run([init])
    print(sess.run([rangeVar1])) # outputs [1], ok
    print(sess.run([rangeVar2])) # outputs [1, 2], ok
    print(sess.run([var1moments]))
    print(sess.run([var2moments]))

它抛出:

raise TypeError("'Tensor' object is not iterable.")

Tags: runimport参数tfasnpgraphsess
1条回答
网友
1楼 · 发布于 2024-06-16 13:13:00

我使用get\u shape()找到了一个解决方案:

def getMoments(myTensor):
    myRank = len(myTensor.get_shape().as_list())
    print('rank via shape:', myRank)
    myMoments = tf.nn.moments(myTensor, axes=list(range(1, myRank)))
    return myMoments

相关问题 更多 >