如何绘制AUC和ROC,同时使用fit帴u generator和evaluate帴生成器来训练我的网络?

2024-04-24 06:00:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用生成器来训练和预测我的数据分类。下面是ImageDataGenerator的一个示例

from keras.preprocessing.image import ImageDataGenerator

batch_size = 16

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)


train_generator = train_datagen.flow_from_directory(
        'data/train',  # this is the target directory
        target_size=(150, 150),  
        batch_size=batch_size,
        class_mode='binary') 

validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')


model.fit_generator(
        train_generator,
        steps_per_epoch=2000 // batch_size,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800 // batch_size)
model.save_weights('first_try.h5')  # always save your weights after training or during training

我的问题是,当我使用fit_generator时,如何创建AUC and ROC?在


Tags: fromtesttargetdatasizebatchrangetrain
1条回答
网友
1楼 · 发布于 2024-04-24 06:00:08

我认为在这种情况下,最好的办法是将AUC定义为一个新的指标。为此,您必须在tensorflow中定义度量(我假设您使用的是tensorflow后端)。在

我以前曾尝试过一种方法(但是,我不记得我测试过结果的正确性)如下:

def as_keras_metric(method):
    """
    This is taken from:
    https://stackoverflow.com/questions/45947351/how-to-use-tensorflow-metrics-in-keras/50527423#50527423
    """
    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        """ Wrapper for turning tensorflow metrics into keras metrics """
        value, update_op = method(*args, **kwargs)
        tf.keras.backend.get_session().run(tf.local_variables_initializer())
        with tf.control_dependencies([update_op]):
            value = tf.identity(value)
        return value
    return wrapper

然后在编译模型时定义度量:

^{pr2}$

虽然这会给出一些数字,但我还没有弄清楚这些数字是否正确。如果你能测试这个,并且它给出了正确的结果,或者没有,请告诉我,我会很感兴趣的。在

解决此问题的第二种方法是使用callback class并至少定义on_epoch_end函数,然后可以从那里调用sklearnroc_auc_score,并打印出来或保存到日志中。在

但是,到目前为止,我发现,您需要通过__init__向它提供训练数据,因此对于生成器,您需要确保回调的生成器提供的数据与模型的拟合生成器相同。另一方面,对于验证生成器,可以使用self.validation_data从回调类访问它,这与提供给fit_generator的相同。在

相关问题 更多 >