Keras如何使用argmax进行预测

2024-04-19 09:45:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我有3类Tree, Stump, Ground。我列出了这些类别的列表:

CATEGORIES = ["Tree", "Stump", "Ground"]

当我打印我的预测时,它给了我

[[0. 1. 0.]]

我已经读过关于numpy的Argmax,但我不完全确定如何在这种情况下使用它。

我试过用

print(np.argmax(prediction))

但这给了我1的输出。这很好,但我想找出1的索引,然后打印出类别而不是最高值。

import cv2
import tensorflow as tf
import numpy as np

CATEGORIES = ["Tree", "Stump", "Ground"]


def prepare(filepath):
    IMG_SIZE = 150 # This value must be the same as the value in Part1
    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
    return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)

# Able to load a .model, .h3, .chibai and even .dog
model = tf.keras.models.load_model("models/test.model")

prediction = model.predict([prepare('image.jpg')])
print("Predictions:")
print(prediction)
print(np.argmax(prediction))

我希望我的预测能告诉我:

Predictions:
[[0. 1. 0.]]
Stump

谢谢你的阅读:)我非常感谢你的帮助。


Tags: importtreeimgsizemodelasnp类别