Pythorch模块训练器
pytorchtrainer的Python项目详细描述
Pythorch培训师
您是否厌倦了编写那些相同的epoch和数据加载程序循环来训练pytorch模块? 别再看了,Pythorch Trainer是一个库,它隐藏了所有那些应该是Pythorch本机的枯燥的训练代码行。
您还将受益于以下功能:
- 早停:停滞一段时间后停止训练
- 检查点:定期保存模型和估计器
- CSV文件写入程序输出日志
- 有几个指标可用:所有默认Pythorch损失函数、精度、MAE
- 控制台上的进度条
- sigint处理:handle ctrl-c
- 模型的数据类型(
float32
,float64
)
示例
代码示例可以在example folder中找到。
下面是一个简单的示例:
importtorchimportpytorchtrainerasptt# Your usual model, optimizer, loss function and data loadersmodel=MyModel()optimizer=torch.optim.Adam(self.model.parameters(),lr=1e-3)criterion=torch.nn.MSELoss()train_loader=MyTrainDataloader()validation_loader=MyValidationDataloader()# instantiate a default trainertrainer=ptt.create_default_trainer(model,optimizer,criterion)# optionally save a checkpoint after every 10 epochstrainer.register_post_epoch_callback(ptt.checkpoint.SaveCheckpointCallback(save_every=10))# optionally compute validation loss after every epochvalidation_callback=ptt.callback.ValidationCallback(validation_loader,ptt.metric.Loss(criterion),validate_every=1)trainer.register_post_epoch_callback(validation_callback)# optionally save training and validation loss after every iteration using default save directorytrainer.register_post_iteration_callback(ptt.callback.CsvWriter(save_every=1,extra_header=[validation_callback.state_attribute_name],callback=lambdastate:[state.get(validation_callback.state_attribute_name)]))# run the trainingtrainer.train(train_loader,max_epochs=100)
依赖关系
- Python>;3.5
- pytorch 1.0.1(安装说明来自官方PyTorch website)
贡献
请随时提交问题或请求。但是在你阅读之前请先阅读contributing guidelines