使用十位数轻松检查自定义渐变

2024-04-25 20:59:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个非常简单的问题:我在Tensorflow中实现了一个复杂的自定义操作及其梯度,假设前进方向是正确的,我想知道是否有一种简单的方法来检查有限差分是否很好地逼近了不同点的自定义渐变,而不必以丑陋的方式重新实现它。我看到了函数tf.test.compute_gradient_error()in the official doc,但是源代码很密集,很难阅读,我似乎找不到任何其他相关的问题或示例。
不过,我肯定有一个超级简单的自我包含的例子,我错过了? 编辑: 例如,如果我尝试:

import tensorflow as tf
import numpy as np
start=np.random.normal(size=(100,1)).astype("float32")
x=tf.Variable(start)
w=2*tf.ones((1,1),dtype="float32")
y=tf.matmul(x,w)
#I differentiate y wrt x, which is a variable
check=tf.test.compute_gradient_error(x,[100,1],y,[100,1],x_init_value=start)
sess=tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(check)

它抛出: AttributeError:“NoneType”对象没有属性“run” 寻找梯度_棋盘格.py我做错什么了?在


Tags: runtestimporttftensorflowcheckasnp
1条回答
网友
1楼 · 发布于 2024-04-25 20:59:15

所以我的问题是gradient_checker.py调用get_default_session()来获取它所使用的会话,如果操作没有显式地连接到正在使用的会话,这显然是不起作用的,这是通过限定op的作用域来完成的:

with sess.as_default_session():
  check=tf.test.compute_gradient_error()
  print check

还需要说明的是,之所以需要这样做,是因为检查是张量的sess.run()的结果,而不是像大多数张量流函数那样的图中的节点。在

相关问题 更多 >