我尝试用in_top_k函数来测试这个函数到底在做什么。但我发现了一些令人困惑的行为。在
首先我编码如下
import numpy as np
import tensorflow as tf
target = tf.constant(np.random.randint(2, size=30).reshape(30,-1), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
targetVal = target.eval()
predVal = pred.eval()
resultVal = result.eval()
然后生成以下错误:
^{pr2}$然后我把代码改成
import numpy as np
import tensorflow as tf
target = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1).reshape(-1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
targetVal = target.eval()
predVal = pred.eval()
resultVal = result.eval()
但现在错误变成了
ValueError: Shape must be rank 2 but is rank 1 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [30], [30], [].
那么输入应该是秩1还是秩2?在
对于}秩2(每个类的分数)。这很容易看出from the docs。
in_top_k
,targets
必须是秩1(类索引)和{这意味着这两个错误消息实际上每次都会抱怨不同的输入(第一次是目标,第二次是预测),有趣的是消息中根本没有提到这一点。。。不管怎样,下面的片段应该更像:
在这里,我们基本上结合了“最好的两个片段”:第一个片段的预测和第二个片段的目标。然而,以我理解文档的方式,即使是二进制分类,我们也需要两个值用于预测,每个类一个值。所以有点像
^{pr2}$相关问题 更多 >
编程相关推荐