How to fix: InvalidArgumentError: Graph execution error?

Asked 2025-04-12 08:11

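I've just started learning computer vision and image classification, and I want to do transfer learning with Keras's VGG16 model. However, I keep getting an error when I run the code below. Can anyone help me or give me some advice?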

import tensorflow as tf
from tensorflow.keras import models, layers
import matplotlib.pyplot as plt
from glob import glob
import os

IMG_SIZE = 224
BATCH_SIZE = 32
CHANNELS = 3
EPOCH = 30

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "../Notebooks/Dataset",
    shuffle=True,
    image_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
)

class_names = dataset.class_names
class_names

def preprocess_image(image):
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image /= 255.0  # normalize to [0,1] range
    print("HI")
    print(image)
    return image

# Apply preprocessing and augmentation
dataset = dataset.map(
    lambda x, y: (preprocess_image(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)

def apply_augmentation(image, label):
    # Random horizontal flip
    image = tf.image.random_flip_left_right(image)
    # Random rotation
    image = tf.image.rot90(image, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32))
    # Random brightness adjustment
    image = tf.image.random_brightness(image, max_delta=0.1)
    return image, label

dataset = dataset.map(
    apply_augmentation,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)


train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset = dataset.take(train_size)
val_dataset = dataset.skip(train_size).take(val_size)
test_dataset = dataset.skip(train_size).skip(val_size)

vgg = tf.keras.applications.VGG16(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    weights='imagenet',
    include_top=False
)


for layer in vgg.layers:
    layer.trainable = False

x = layers.Flatten()(vgg.output)
prediction = layers.Dense(3, activation='relu')(x)

model = tf.keras.models.Model(inputs=vgg.input, outputs=prediction)
model.summary()

model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    steps_per_epoch=train_size,
    validation_steps=val_size
)

When I start training the model, I get the following error:

2024-03-29 09:36:18.734909: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at sparse_xent_op.cc:103 : INVALID_ARGUMENT: Received a label value of 3 which is outside the valid range of [0, 3).  Label values: 2 3 1 0 1 2 2 2 1 2 2 0 0 0 3 0 0 0 2 0 3 1 0 2 2 2 2 0 3 0 2 0
2024-03-29 09:36:18.734957: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: Received a label value of 3 which is outside the valid range of [0, 3).  Label values: 2 3 1 0 1 2 2 2 1 2 2 0 0 0 3 0 0 0 2 0 3 1 0 2 2 2 2 0 3 0 2 0
     [[{{function_node __inference_one_step_on_data_2794}}{{node compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]]

InvalidArgumentError                      Traceback (most recent call last)
Input In [15], in <cell line: 1>()
----> 1 history = model.fit(
      2     train_dataset,
      3     validation_data=val_dataset,
      4     epochs=10,
      5     steps_per_epoch=train_size,
      6     validation_steps=val_size
      7 )

File ~/anaconda3/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51 try:
     52   ctx.ensure_initialized()
---> 53   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                       inputs, attrs, num_outputs)
     55 except core._NotOkStatusException as e:
     56   if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits defined at (most recent call last):
.
.
.

I searched for this problem but couldn't find an answer. How can I fix InvalidArgumentError: Graph execution error?
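The label values printed in the error log already point at the cause. `sparse_categorical_crossentropy` needs every integer label to be a valid index into the model's output vector. The final layer in the question is `layers.Dense(3, ...)`, so only labels 0, 1 and 2 are accepted, yet the batch contains label 3. The minimal plain-Python sketch below reads the label values copied from the log and computes how many output units the model would need. The helper names are illustrative, not part of Keras.

```python
# Label values copied verbatim from the error log in the question (one batch of 32).
LOGGED_LABELS = [int(v) for v in
                 "2 3 1 0 1 2 2 2 1 2 2 0 0 0 3 0 0 0 2 0 3 1 0 2 2 2 2 0 3 0 2 0".split()]

def required_output_units(labels):
    """Smallest number of output units for which every label is a valid index."""
    return max(labels) + 1

def out_of_range_labels(labels, num_units):
    """Labels that sparse categorical cross-entropy would reject for `num_units` outputs."""
    return sorted({v for v in labels if not 0 <= v < num_units})

print(required_output_units(LOGGED_LABELS))   # 4
print(out_of_range_labels(LOGGED_LABELS, 3))  # [3], the value reported in the error
print(out_of_range_labels(LOGGED_LABELS, 4))  # []
```

Because this batch already contains labels 0 through 3, the dataset directory must contain at least four class subfolders. A single batch only gives a lower bound, though; `len(dataset.class_names)` is the authoritative count.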
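With its default `labels='inferred'`, `image_dataset_from_directory` creates one class per subdirectory of `../Notebooks/Dataset`. It assigns integer labels in alphanumeric order of the folder names, and `dataset.class_names` lists them in that order. As a result, an unexpected extra subfolder silently adds a class. The stdlib sketch below approximates that inference on a temporary directory. `infer_class_names` and the folder names are hypothetical, since the question does not say what its classes are. The real Keras function may also apply extra filtering that this sketch ignores.

```python
import os
import tempfile

def infer_class_names(directory):
    """Approximate Keras' inferred labels: one class per subdirectory, sorted by name."""
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )

with tempfile.TemporaryDirectory() as root:
    # Hypothetical class folders; only directories count as classes here.
    for name in ["dog", "cat", "fish", "bird"]:
        os.makedirs(os.path.join(root, name))
    # A loose file at the top level is not treated as a class by this sketch.
    open(os.path.join(root, "readme.txt"), "w").close()
    class_names = infer_class_names(root)

label_of = {name: index for index, name in enumerate(class_names)}
print(class_names)  # ['bird', 'cat', 'dog', 'fish']
print(label_of)     # {'bird': 0, 'cat': 1, 'dog': 2, 'fish': 3}
```

Printing `len(class_names)` right after loading the data shows how many units the output layer needs. The question's code already evaluates `class_names` at that point.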
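In Keras terms, the usual fix is to size the classification head from the data and give it a softmax activation, for example `prediction = layers.Dense(len(class_names), activation='softmax')(x)`. With its default `from_logits=False`, `sparse_categorical_crossentropy` expects each output row to be a probability distribution, and `relu` outputs are not probabilities. The plain-Python sketch below is not the Keras implementation and only mirrors the arithmetic. Softmax turns scores into probabilities, and the loss for a label is minus the log of that label's probability. That loss is undefined when the label is not an index into the output.

```python
import math

def softmax(scores):
    """Convert raw scores into probabilities that are positive and sum to 1."""
    m = max(scores)                        # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_categorical_crossentropy(label, probs):
    """-log(probs[label]); rejects labels that are not valid indices, like the TF op."""
    if not 0 <= label < len(probs):
        raise ValueError(
            f"Received a label value of {label} which is outside "
            f"the valid range of [0, {len(probs)})."
        )
    return -math.log(probs[label])

# Hypothetical scores for one image from a 4-unit head and a 3-unit head.
four_unit_probs = softmax([0.2, 1.0, -0.5, 0.3])
three_unit_probs = softmax([0.2, 1.0, -0.5])

print(sparse_categorical_crossentropy(3, four_unit_probs))  # a finite, positive loss
try:
    sparse_categorical_crossentropy(3, three_unit_probs)
except ValueError as err:
    print(err)  # the same complaint as in the question's log
```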
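Two smaller points concern `preprocess_image` in the question. First, `image_dataset_from_directory` already resizes every image to `image_size`, so the extra `tf.image.resize` is redundant. Also, the Python `print` calls run only once, when `Dataset.map` traces the function, not for every batch; printing per batch requires `tf.print`. Second, the ImageNet weights of VGG16 were trained with a different input convention than dividing by 255. The Keras documentation for `tf.keras.applications.vgg16.preprocess_input` describes it as converting RGB to BGR and zero-centering each channel with respect to ImageNet, without scaling. The NumPy sketch below reproduces that convention for illustration only, assuming raw pixel values in [0, 255]. In real code, calling `preprocess_input` is the safer choice.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by the "caffe" convention.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_style_preprocess(images):
    """Sketch of VGG16-style preprocessing for raw RGB pixels in [0, 255].

    Works on any array whose last axis holds the 3 color channels.
    """
    x = np.asarray(images, dtype=np.float32)
    x = x[..., ::-1]              # reverse the channel axis: RGB -> BGR
    return x - IMAGENET_BGR_MEAN  # zero-center each channel, no scaling

# One hypothetical 1x1 RGB image with R=10, G=20, B=30.
pixel = np.array([[[10.0, 20.0, 30.0]]])
print(caffe_style_preprocess(pixel))  # BGR values minus the per-channel means
```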
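In the question, `apply_augmentation` is mapped over `dataset` before the split. As a result, the validation and test batches are also randomly flipped, rotated and brightened, which makes their metrics noisy. Mapping it over `train_dataset` only, after the `take`/`skip` split, avoids that. On a batched dataset, `tf.image.rot90` (given one `k`) and `tf.image.random_brightness` apply one rotation and one brightness offset to the whole batch, while `random_flip_left_right` decides per image. The NumPy sketch below mirrors that behavior on one batch, assuming square images in NHWC layout. `augment_batch` is a hypothetical helper, not a TensorFlow API.

```python
import numpy as np

def augment_batch(images, rng, max_delta=0.1):
    """Flip / rotate / brighten a batch shaped (batch, height, width, channels).

    Assumes height == width so a 90-degree rotation keeps the shape.
    """
    out = np.array(images, dtype=np.float32)  # work on a float copy
    # Flip each image left-right independently with probability 0.5.
    flip = rng.random(out.shape[0]) < 0.5
    out[flip] = out[flip, :, ::-1, :]
    # Rotate the whole batch by one random multiple of 90 degrees.
    k = int(rng.integers(0, 4))
    out = np.rot90(out, k=k, axes=(1, 2))
    # Shift the brightness of the whole batch by one random offset.
    return out + rng.uniform(-max_delta, max_delta)

rng = np.random.default_rng(0)
batch = rng.random((2, 4, 4, 3))
augmented = augment_batch(batch, rng)
print(augmented.shape)  # (2, 4, 4, 3)
```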
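The split in the question counts batches, not images. `len(dataset)` is the number of batches, so `steps_per_epoch=train_size` and `validation_steps=val_size` simply match the sizes of the `take`/`skip` slices and could be left out. The stdlib sketch below reproduces that arithmetic and slicing on a list of batch indices. `split_batches` is a hypothetical helper, and 37 is an arbitrary batch count. One caveat the sketch cannot show concerns shuffling. If the source dataset is reshuffled on every pass, which may happen with the `shuffle=True` default of `image_dataset_from_directory`, the images that land in each slice can change from epoch to epoch. The loader's `validation_split`, `subset` and `seed` arguments are a common way to get a fixed split.

```python
from itertools import islice

def split_batches(num_batches, train_frac=0.8, val_frac=0.1):
    """Same arithmetic as the question: truncate train/val, give the rest to test."""
    train_size = int(train_frac * num_batches)
    val_size = int(val_frac * num_batches)
    test_size = num_batches - train_size - val_size
    return train_size, val_size, test_size

batches = list(range(37))  # stand-in for 37 batches
train_size, val_size, test_size = split_batches(len(batches))

# Like Dataset.take/skip: take(n) keeps the first n items, skip(n) drops them.
train = list(islice(batches, train_size))
val = list(islice(batches, train_size, train_size + val_size))
test = list(islice(batches, train_size + val_size, None))

print(train_size, val_size, test_size)  # 29 3 5
```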
