schedule: a function that takes an epoch index (integer, indexed from 0) and current learning rate (float) as inputs and returns a new learning rate as output (float).
To implement your own schedule object, you should implement the
call method, which takes a step argument (scalar integer tensor, the current training step count).
tf.keras.callbacks.LearningRateScheduler()
和tf.keras.optimizers.schedules.LearningRateSchedule()
都提供相同的功能,即在训练模型时实现学习速率衰减一个明显的区别可能是
tf.keras.callbacks.LearningRateScheduler
在其构造函数中接受一个函数,如文档中所述函数将返回给定当前历元索引的学习速率。要实现各种类型的LR衰减,如指数衰减、多项式衰减等,您需要自己用这个
schedule
方法对它们进行编码另一方面,
tf.keras.optimizers.schedules.LearningRateSchedule()
是一个高级类。tf.keras.optimizers.schedules.*
中包含的其他类型的衰变,如PolynomialDecay
或InverseTimeDecay
继承此类。因此,该模块提供了ML中常用的内置LR衰减方法。此外,要实现自定义LR衰减,您的类需要继承tf.keras.optimizers.schedules.LearningRateSchedule()
并重写__call__
和__init__
等方法,如文档中所述结论:
如果您想使用一些内置的LR衰减,请使用
tf.keras.optimizers.schedules.*
模块,即该模块中提供的LR衰减如果您需要一个简单的自定义LR衰减,它只需要epoch索引作为参数,请使用
tf.keras.callbacks.LearningRateScheduler
如果自定义LR Decay需要的参数不仅仅是历元索引,那么创建一个新类并继承
tf.keras.optimizers.schedules.LearningRateSchedule
相关问题 更多 >
编程相关推荐