如何调试TensorFlow中的NaN值?

2024-06-06 19:18:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我在运行TensorFlow,碰巧遇到了一个NaN。我想知道是什么,但我不知道怎么做。主要的问题是,在一个“正常”的过程程序中,我只需在执行操作之前编写一个print语句。TensorFlow的问题是,我不能这样做,因为我首先声明(或定义)图形,所以向图形定义添加print语句没有帮助。有没有什么规则,建议,启发法,任何东西来追踪什么可能导致南部?


在这种情况下,我更确切地知道要看哪一行,因为我有以下几点:

Delta_tilde = 2.0*tf.matmul(x,W) - tf.add(WW, XX) #note this quantity should always be positive because its pair-wise euclidian distance
Z = tf.sqrt(Delta_tilde)
Z = Transform(Z) # potentially some transform, currently I have it to return Z for debugging (the identity)
Z = tf.pow(Z, 2.0)
A = tf.exp(Z) 

当这一行出现时,我有它返回我的摘要作者声明的NaN。这是为什么?有没有办法至少探究Z在平方根之后的值?


对于我发布的特定示例,我尝试了tf.Print(0,Z),但没有成功,它什么也没有打印出来。如所示:

Delta_tilde = 2.0*tf.matmul(x,W) - tf.add(WW, XX) #note this quantity should always be positive because its pair-wise euclidian distance
Z = tf.sqrt(Delta_tilde)
tf.Print(0,[Z]) # <-------- TF PRINT STATMENT
Z = Transform(Z) # potentially some transform, currently I have it to return Z for debugging (the identity)
Z = tf.pow(Z, 2.0)
A = tf.exp(Z) 

我真的不明白tf.Print应该做什么。为什么它需要两个论点?如果我想打印1个张量,为什么我需要通过2?我觉得很奇怪。


我在看函数tf.add_check_numerics_ops(),但它没有说明如何使用它(加上文档似乎没有太大帮助)。有人知道怎么用这个吗?


因为我已经有了处理数据可能不好的评论,所以我使用标准MNIST。然而,我正在计算一个正的量(成对欧几里德距离),然后平方根它。因此,我看不出具体的数据是如何成为一个问题的。


Tags: add声明图形定义tftensorflow语句nan
3条回答

获得NaN结果的原因有很多,通常是因为学习率太高,但也有很多其他原因,例如输入队列中的数据损坏或0计算日志。

无论如何,用您描述的print进行调试不能用简单的print来完成(因为这只会导致在图中打印张量信息,而不会打印任何实际值)。

但是,如果在构建图(tf.print)时使用tf.print作为操作,那么当执行图时,将打印实际值(观察这些值是一个很好的练习,可以调试和理解网络的行为)。

但是,您使用print语句的方式并不完全正确。这是一个op,因此需要传递一个张量并请求一个结果张量,稍后在执行图中需要使用它。否则不会执行操作,也不会进行打印。试试这个:

Z = tf.sqrt(Delta_tilde)
Z = tf.Print(Z,[Z], message="my Z-values:") # <-------- TF PRINT STATMENT
Z = Transform(Z) # potentially some transform, currently I have it to return Z for debugging (the identity)
Z = tf.pow(Z, 2.0)

我曾经发现,要找出nans和infs可能出现的位置比修复bug要困难得多。作为对@scai回答的补充,我想在这里补充几点:

调试模块,可以通过以下方式导入:

from tensorflow.python import debug as tf_debug

比任何印刷品或断言都好得多。

您可以通过以下方式更改会话的包装器来添加调试函数:

sess = tf_debug.LocalCLIDebugWrapperSession(sess)
sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

然后提示命令行界面,然后输入: run -f has_inf_or_nanlt -f has_inf_or_nan来查找nans或inf的位置。第一个是灾难发生的第一个地方。通过变量名,您可以跟踪代码中的源代码。

引用:https://developers.googleblog.com/2017/02/debug-tensorflow-models-with-tfdbg.html

看起来你可以在完成图表制作后调用它。

check = tf.add_check_numerics_ops()

我想这会增加对所有浮点操作的检查。然后在sessions run函数中添加检查操作。

sess.run([check, ...])

相关问题 更多 >