具有其他输入形状和图像净重的VGG16

2024-05-15 02:25:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我不熟悉VGG16这样的型号。我一直在搜索关于这个模型的信息,但我仍然对它有怀疑。我有10000张不同大小的图片来训练这个模型(2个类),所以我决定使用86x86的图片大小,因为计算上的限制,它几乎是每个图片大小的平均值。所以我这样做了:

base_model16 = VGG16(weights='imagenet', include_top=False, input_shape=(86,86,3)) 

至于发电机:

datagen = ImageDataGenerator(preprocessing_function=preprocess_vgg16) 

train_generator = datagen.flow_from_directory(path_train,
                                                    target_size=(86,86),
                                                    color_mode='rgb',
                                                    batch_size = 128,
                                                    class_mode='categorical',
                                                    shuffle=True) 

我读到VGG16是用224x224训练的,我知道我们可以使用其他尺寸,但有人能确认我做得对吗?因为我使用的是imagenet weights和preprocess_vgg16,它使用的是224x224。 对不起,如果以前有人问过这个问题,请帮我理解

多谢各位


Tags: 模型信息sizemode图片trainimagenet型号
1条回答
网友
1楼 · 发布于 2024-05-15 02:25:15

您必须修改Vgg模型,因为它设计用于对1000幅图像进行分类。设置include_top=False将删除具有1000个神经元的模型顶层。现在我们需要包括一个层,其中将有2个神经元。下面的代码将实现这一点。注意,在VGG模型的参数中,我设置了pooling='max'。这使得Vgg模型的输出成为一个向量,可以用作密集层的输入

base_model=tf.keras.applications.VGG16( include_top=False, input_shape=(86,86,3), 
                                        pooling='max', weights='imagenet' ) 
x=base_model.output
output=Dense(2, activation='softmax')(x)
model=Model(inputs=base_model.input, outputs=output)
model.compile(Adam(lr=.001), loss='categorical_crossentropy', metrics=['accuracy') 

顺便说一句,我不喜欢使用VGG16。它有大约4000万个可转换的参数,因此计算开销大,导致训练时间长。我更喜欢使用MobileNet模型,该模型只有大约400万个可训练参数,而且精度也差不多。要使用MobileNet模型,只需使用这行代码,而不是Vgg模型的代码。注:我将图像形状设置为(128128,3),因为有一个版本的mobilenet权重在imagenet上使用128 X 128图像进行训练,它将自动下载并帮助模型更快收敛。但如果您愿意,可以使用86 X86。因此,在您的列车中,发电机组的目标是(128128)。此外,在ImageDataGenerator中,代码预处理函数=预处理函数vgg16仍应适用于Mobilenet模型,因为我认为它与keras.applications.Mobilenet.preprocess输入相同。我相信它们都只是将像素重新缩放到-1和+1之间

base_model=tf.keras.applications.mobilenet.MobileNet( include_top=False, 
           input_shape=(128,128,3), pooling='max', weights='imagenet',dropout=.4)

相关问题 更多 >

    热门问题