一个友好的python包,用于只基于NumPy的Keras超参数调优。
keras-hypetune的Python项目详细描述
keras催眠
一个友好的python包,用于仅基于NumPy调整Keras超参数。在
概述
一个非常简单的包装器,用于快速Keras超参数优化。kerashypetune让你不用学习新的语法就可以使用keras的强大功能。你所需要的只是创建一个python字典,在那里为实验设置参数边界,并在一个可调用函数中定义Keras模型(任何格式:函数式或顺序式)。在
defget_model(param):model=Sequential()model.add(Dense(param['unit_1'],activation=param['activ']))model.add(Dense(param['unit_2'],activation=param['activ']))model.add(Dense(1))model.compile(optimizer=Adam(learning_rate=param['lr']),loss='mse',metrics=['mae'])returnmodel
使用Keras提供的回调可以很容易地跟踪优化过程。在该过程的最后,您可以访问查询keras hypetune searcher所需的所有内容。最佳解决方案可以自动保存在适当的位置。在
安装
^{pr2}$Tensorflow和Keras不是必需的要求。keras hypetune是专门为keras公司使用TensorFlow 2.0。GPU的使用是正常可用的。在
固定验证集
此优化模式在固定的验证集上执行优化。参数组合总是在同一组数据上计算。在这种情况下,允许使用Keras接受的任何类型的输入数据格式。在
KerasGridSearch
创建并计算所有传递的参数组合。在
param_grid={'unit_1':[128,64],'unit_2':[64,32],'lr':[1e-2,1e-3],'activ':['elu','relu'],'epochs':100,'batch_size':512}kgs=KerasGridSearch(get_model,param_grid,monitor='val_loss',greater_is_better=False)kgs.search(x_train,y_valid,validation_data=(x_valid,y_valid))
KerasRandomSearch
只创建和计算随机参数组合。在
尝试的参数组合数由nˉiter给出。如果所有参数以列表形式显示,则执行不替换的采样。如果至少有一个参数作为分布(从scipy.stats公司随机变量),使用替换抽样法。在
param_grid={'unit_1':[128,64],'unit_2':stats.randint(32,128),'lr':stats.uniform(1e-4,0.1),'activ':['elu','relu'],'epochs':100,'batch_size':512}krs=KerasRandomSearch(get_model,param_grid,monitor='val_loss',greater_is_better=False,n_iter=15,sampling_seed=33)krs.search(x_train,y_valid,validation_data=(x_valid,y_valid))
交叉验证
这种调优方式使用交叉验证方法来操作优化。可用的CV策略与scikit learn splitter类提供的相同。参数组合以褶皱的平均得分为基础进行评估。在这种情况下,只允许使用numpy数组数据。对于涉及多输入/输出的任务,可以像普通Keras一样将数组包装成list或dict。在
KerasGridSearchCV
创建并计算所有传递的参数组合。在
param_grid={'unit_1':[128,64],'unit_2':[64,32],'lr':[1e-2,1e-3],'activ':['elu','relu'],'epochs':100,'batch_size':512}cv=KFold(n_splits=3,random_state=33,shuffle=True)kgs=KerasGridSearchCV(get_model,param_grid,cv=cv,monitor='val_loss',greater_is_better=False)kgs.search(X,y)
KerasRandomSearchCV
只创建和计算随机参数组合。在
尝试的参数组合数由nˉiter给出。如果所有参数以列表形式显示,则执行不替换的采样。如果至少有一个参数作为分布(从scipy.stats公司随机变量),使用替换抽样法。在
param_grid={'unit_1':[128,64],'unit_2':stats.randint(32,128),'lr':stats.uniform(1e-4,0.1),'activ':['elu','relu'],'epochs':100,'batch_size':512}cv=KFold(n_splits=3,random_state=33,shuffle=True)krs=KerasRandomSearchCV(get_model,param_grid,cv=cv,monitor='val_loss',greater_is_better=False,n_iter=15,sampling_seed=33)krs.search(X,y)
媒体
用于:
- 项目
标签: