2024-03-28 14:52:14 发布
网友
我将logits与循环中的标签进行比较:
for r in range(logits.shape[0]): if labels[r] == np.argmax(logits[r]): guessed += 1.0
其中labels是整数标签的一维数组,logits是二维数组,第二维是标签的概率。你知道吗
labels
logits
上面的解决方案是一个Python循环,效率不高。应该有一个常用的numpy或tensorflow快捷方式来做到这一点。你能推荐一个吗?你知道吗
numpy
tensorflow
您可以通过np.argmax(logits,axis=1)一次获得所有最大值。以下内容可以替换for循环以获得猜测的总数:
np.argmax(logits,axis=1)
guessed = np.sum(labels == np.argmax(logits,axis=1))
您可以通过
np.argmax(logits,axis=1)
一次获得所有最大值。以下内容可以替换for循环以获得猜测的总数:相关问题 更多 >
编程相关推荐