如何在TensorFlow中打印张量对象的值?

2024-04-25 19:45:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直在使用TensorFlow中矩阵乘法的介绍性示例。

matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)

当我打印产品时,它显示为一个Tensor对象:

<tensorflow.python.framework.ops.Tensor object at 0x10470fcd0>

但是我怎么知道product的值呢?

以下情况没有帮助:

print product
Tensor("MatMul:0", shape=TensorShape([Dimension(1), Dimension(1)]), dtype=float32)

我知道图是在Sessions上运行的,但是没有任何方法可以检查Tensor对象的输出而不在session中运行图吗?


Tags: 对象示例产品tftensorflow矩阵producttensor
3条回答

计算Tensor对象的实际值的最简单方法是将其传递给Session.run()方法,或者在有默认会话时调用Tensor.eval()(即在with tf.Session():块中,或见下文)。一般来说,如果不在会话中运行某些代码,则无法打印张量的值。

如果您正在尝试编程模型,并且想要一种简单的方法来计算张量,^{}允许您在程序开始时打开一个会话,然后将该会话用于所有的Tensor.eval()(和Operation.run())调用。在交互式设置(如shell或IPython笔记本)中,当到处传递Session对象很乏味时,这可能更容易实现。例如,以下内容适用于Jupyter笔记本:

with tf.Session() as sess:  print(product.eval()) 

对于这样一个小的表达式来说,这可能看起来很傻,但是Tensorflow中的一个关键思想是延迟执行:构建一个大而复杂的表达式是非常便宜的,并且当您想要对其求值时,后端(与Session连接)能够更有效地调度其执行(例如执行独立的并行部件和使用GPU)。


[A]:要打印张量的值而不将其返回到Python程序,可以使用^{}运算符作为Andrzej suggests in another answer。请注意,您仍然需要运行图形的一部分以查看此操作的输出,该操作将打印为标准输出。如果运行的是分布式TensorFlow,tf.Print()将其输出打印到运行该op的任务的标准输出。这意味着,如果您使用https://colab.research.google.com例如,或任何其他Jupyter笔记本,那么您将在笔记本中看不到^{}的输出;在这种情况下,请参阅this answer关于如何使其静止打印。

[B]:您可以使用实验函数^{}来获取常量张量的值,但它不是一般用途,也不是为许多运算符定义的。

重复别人的话,不运行图表就不可能检查值。

下面是一个简单的代码片段,任何人都可以通过它来寻找一个简单的示例来打印值。在ipython笔记本中,代码可以在不做任何修改的情况下执行

import tensorflow as tf

#define a variable to hold normal random values 
normal_rv = tf.Variable( tf.truncated_normal([2,3],stddev = 0.1))

#initialize the variable
init_op = tf.initialize_all_variables()

#run the graph
with tf.Session() as sess:
    sess.run(init_op) #execute init_op
    #print the random values that we sample
    print (sess.run(normal_rv))

输出:

[[-0.16702934  0.07173464 -0.04512421]
 [-0.02265321  0.06509651 -0.01419079]]

虽然其他的答案是正确的,即在对图形求值之前不能打印该值,但它们并没有讨论一种简单的方法,即在对图形求值之后,在图形中实际打印一个值。

每当对图求值(使用runeval)时,查看张量值的最简单方法是使用^{}操作,如下例所示:

# Initialize session
import tensorflow as tf
sess = tf.InteractiveSession()

# Some tensor we want to print the value of
a = tf.constant([1.0, 3.0])

# Add print operation
a = tf.Print(a, [a], message="This is a: ")

# Add more elements of the graph using a
b = tf.add(a, a)

现在,每当我们评估整个图表时,例如使用b.eval(),我们得到:

I tensorflow/core/kernels/logging_ops.cc:79] This is a: [1 3]

相关问题 更多 >