一个通过git管理tensorflow实验并减少样板文件的简单库。与tf 1.x兼容
tfExperiment的Python项目详细描述
t实验
一个通过git管理tensorflow实验并减少样板文件的简单库。 与tf 1.x兼容
用法
这个库依赖于git来管理实验。每个实验都应该是一个惟一的git分支,如果不给出实验的名称,那么它就是当前的git分支。
experiment=tfExperiment.Experiment(finalizeGraph=False)experiment.saveGraph()# output# > graph location ======================================># > tensorboard --logdir output\experimentName\graph# with withwithexperiment.trainingSession(epochs=125,saveAfter=2,testAfter=2)asts:ts.saveGraph()# function to save the graphts.trainCallback=runTrainingCallbackts.testCallback=runTestCallback# as functionexperiment.train(runTrainingCallback)experiment.test(runTestCallback)
API
__init__(name = None, finalizeGraph = False, location = os.path.join(os.getcwd(), 'output'))
name: string
:实验的名称,如果没有提供名称,则将使用当前git分支的名称。finalizeGraph: bool
:完成图形。attention我没怎么尝试过这个功能。location: string
:实验结果保存在与名称相同的文件夹中的绝对路径
train(trainCallback, epochs = 1, saveModelAfter = 2, saveGraph = False, testCallback = None, testAfter = 0)
运行培训并验证/测试模型
trainCallback: function
:在每个历元运行的函数。这应该包含您的循环,其中包含要为每个批次执行的培训操作。训练回调可以采用两个参数:session(current tf.session)和env(如果使用env,则应使用确切的名称)实验环境,可以访问timer和datasaver等功能。epochs: integer
:要运行的阶段数,即调用trainincallbacks的次数。注意:实验对象会跟踪到目前为止运行的epoch的数量,因此如果再次调用experiment.train
,epoch的数量将从上一个epoch的数量继续增长。saveModelAfter: integer
:运行n个阶段后保存模型。这只考虑当前运行。saveGraph: bool
:是否应该在当前运行时保存图形。testCallback: function
:调用以测试/验证当前网络的函数。类似于TrainCallback。testAfter: integer
:运行n个阶段后测试模型。这只考虑当前运行。
test(testCallback)
对模型运行一次测试/验证
testCallback: function
:调用以测试/验证当前网络的函数。类似于TrainCallback。
env: Box object
env对象包含
env.training.currentEpoch: integer
:自实例初始化以来的时间段数。env.training.currentEpoch: integer
:自实例初始化以来的时间段数。env.training.dataSavePath: path string
:如果在训练期间使用datasaver,则使用数据的路径。env.training.dataSaver: dataSaver Instance
:用于训练到训练文件的datasaver实例(用env.training.datasavepath初始化)。env.testing.dataSavePath: path string
如果在测试期间使用datasaver,则使用数据的路径。env.testing.dataSaver: dataSaver Instance
:用于测试训练文件的datasaver实例(用env.testing.datasavepath初始化)。
提议的新api
defTrainExperiment(Experiment):def__init__(self,constructor,...):#someconfig#self.nrTotEpochs#self.epochsToValidateAfter#...defbeforeEpochdefafterEpochdefbeforeSavedefafterSavedefbeforeTestdefafterTestdefbeforeIterationdefafterIterationdeftrain(session,data,dataProvider=None):return0#trainingLoopPerSessiondefvalidate(session,data,dataProvider=None):return0#trainingLoopPerSessionexperiment(TrainExperiment)