输入到tensorflow的值应该是排名1还是排名2?

2024-04-19 07:48:42 发布

您现在位置:Python中文网/ 问答频道 /正文

我尝试用in_top_k函数来测试这个函数到底在做什么。但我发现了一些令人困惑的行为。在

首先我编码如下

import numpy as np
import tensorflow as tf

target = tf.constant(np.random.randint(2, size=30).reshape(30,-1), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    targetVal = target.eval()
    predVal = pred.eval()
    resultVal = result.eval()

然后生成以下错误:

^{pr2}$

然后我把代码改成

import numpy as np
import tensorflow as tf
target = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1).reshape(-1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    targetVal = target.eval()
    predVal = pred.eval()
    resultVal = result.eval()

但现在错误变成了

ValueError: Shape must be rank 2 but is rank 1 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [30], [30], [].

那么输入应该是秩1还是秩2?在


Tags: nameinimporttargetinittftopas
1条回答
网友
1楼 · 发布于 2024-04-19 07:48:42

对于in_top_ktargets必须是秩1(类索引)和{}秩2(每个类的分数)。这很容易看出from the docs
这意味着这两个错误消息实际上每次都会抱怨不同的输入(第一次是目标,第二次是预测),有趣的是消息中根本没有提到这一点。。。不管怎样,下面的片段应该更像:

import numpy as np
import tensorflow as tf

target = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")
pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")
result = tf.nn.in_top_k(pred, target, 1)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    targetVal = target.eval()
    predVal = pred.eval()
    resultVal = result.eval()

在这里,我们基本上结合了“最好的两个片段”:第一个片段的预测和第二个片段的目标。然而,以我理解文档的方式,即使是二进制分类,我们也需要两个值用于预测,每个类一个值。所以有点像

^{pr2}$

相关问题 更多 >