已纠正的NADAM在Keras实施
keras-rnadam的Python项目详细描述
keras rnadam
在校正后的ADAM中使用Nesterov加速geadient而不是动量
安装
pip install keras_rnadam
用法
importkerasimportnumpyasnpfromkeras_rnadamimportRNAdam# Build toy model with RNAdam optimizermodel=keras.models.Sequential()model.add(keras.layers.Dense(input_shape=(17,),units=3))model.compile(RNAdam(),loss='mse')# Generate toy datax=np.random.standard_normal((4096*30,17))w=np.random.standard_normal((17,3))y=np.dot(x,w)# Fitmodel.fit(x,y,epochs=5)
使用预热
fromkeras_nradamimportRNAdamRNAdam(total_step=10000,warmup_proportion=0.1,min_lr=1e-5)