专注于建立和优化pytorch模型,而不是训练循环
torchtrainer的Python项目详细描述
托希特雷纳
Pythorch模型训练在不失去控制的情况下变得更简单。专注于优化你的模型!概念很大程度上受到了这个很棒的项目torchsample和Keras的启发。 此外,除了应用epoch回调之外,它还允许在长epoch持续时间内每次经过特定数量的批处理(迭代)之后调用回调。
功能
Torchtrainer
- 日志工具
- 指标
- 可视化效果
- 学习速率调度器
- 检查点 用于数据输入的灵活性< < > >
- 每…批次
用法
安装
pip install torchtrainer
示例
fromtorchimportnnfromtorch.optimimportSGDfromtorchtrainerimportTorchTrainerfromtorchtrainer.callbacksimportVisdomLinePlotter,ProgressBar,VisdomEpoch,Checkpoint,CSVLogger, \ EarlyStoppingEpoch,ReduceLROnPlateauCallbackfromtorchtrainer.metricsimportBinaryAccuracymetrics=[BinaryAccuracy()]train_loader=...val_loader=...model=...loss=nn.BCELoss()optimizer=SGD(model.parameters(),lr=0.001,momentum=0.9)# Setup Visdom Environment for your modlplotter=VisdomLinePlotter(env_name=f'Model {11}')# Setup the callbacks of your choicecallbacks=[ProgressBar(log_every=10),VisdomEpoch(plotter,on_iteration_every=10),VisdomEpoch(plotter,on_iteration_every=10,monitor='binary_acc'),CSVLogger('test.csv'),Checkpoint('./model'),EarlyStoppingEpoch(min_delta=0.1,monitor='val_running_loss',patience=10),ReduceLROnPlateauCallback(factor=0.1,threshold=0.1,patience=2,verbose=True)]trainer=TorchTrainer(model)# function to transform batch into inputs to your model and y_true values# if your model accepts multiple inputs, just put all inputs into a tuple (input1, input2), y_truedeftransform_fn(batch):inputs,y_true=batchreturninputs,y_true.float()# prepare your trainer for trainingtrainer.prepare(optimizer,loss,train_loader,val_loader,transform_fn=transform_fn,callbacks=callbacks,metrics=metrics)# train your modelresult=trainer.train(epochs=10,batch_size=10)
回拨
记录器
CSVLogger
CSVLoggerIteration
ProgressBar
可视化和日志记录
VisdomEpoch
优化器
ReduceLROnPlateauCallback
StepLRCallback
正则化
EarlyStoppingEpoch
EarlyStoppingIteration
检查点
Checkpoint
CheckpointIteration
指标
CSVLogger
CSVLoggerIteration
ProgressBar
VisdomEpoch
优化器
ReduceLROnPlateauCallback
StepLRCallback
正则化
EarlyStoppingEpoch
EarlyStoppingIteration
检查点
Checkpoint
CheckpointIteration
指标
目前只实现了BinaryAccuracy
。要实现其他度量,请使用抽象基度量类torchtrainer.metrics.metric.Metric
。
待办事项
- 更多测试
- 指标