Detailed description of the pytorchcheckpoint Python project
pytorch-checkpoint
This package supports saving and loading PyTorch training checkpoints. It is useful when resuming model training from a previous step, and comes in handy when working with spot instances or when trying to reproduce results.
A model is saved not only with its weights, as one would for later inference, but with the entire training state, including the optimizer state and its parameters.
In addition, it allows saving metrics and other values generated during training, such as accuracy and loss values. This makes it possible to recreate learning curves from past values and to keep updating them as training continues.
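The per-iteration bookkeeping described above can be sketched with a plain dictionary keyed by variable name and iteration. This is a simplified stand-in for illustration only, not the library's actual CheckpointHandler implementation:

```python
# Simplified sketch of per-iteration metric storage.
# Illustration only; not pytorchcheckpoint's actual implementation.
class RunningStore:
    def __init__(self):
        # (var_name, iteration) -> value
        self.values = {}

    def store_running_var(self, var_name, iteration, value):
        self.values[(var_name, iteration)] = value

    def get_running_var(self, var_name, iteration):
        return self.values.get((var_name, iteration))

store = RunningStore()
store.store_running_var('loss', 0, 1.0)
store.store_running_var('loss', 1, 0.9)
print(store.get_running_var('loss', 1))  # prints 0.9
```

Because each value is indexed by iteration, the full history of a metric stays retrievable, which is what makes reconstructing learning curves possible.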
Prerequisites
Developed with Python 3.7.3, but should be compatible with earlier Python versions.
pip install torch==1.1.0 torchvision==0.3.0
Installation
pip install pytorchcheckpoint
Usage
from pytorchcheckpoint.checkpoint import CheckpointHandler
checkpoint_handler = CheckpointHandler()
Storing a general value
checkpoint_handler.store_var(var_name='num_of_classes', value=1000)
Reading a general value
num_of_classes = checkpoint_handler.get_var(var_name='num_of_classes')
Storing values and metrics per epoch/iteration. For example, loss values:
checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
checkpoint_handler.store_running_var(var_name='loss', iteration=1, value=0.9)
checkpoint_handler.store_running_var(var_name='loss', iteration=2, value=0.8)
Reading a stored value for a given epoch/iteration
loss = checkpoint_handler.get_running_var(var_name='loss', iteration=0)
Storing values and metrics per set (train/valid/test), per epoch/iteration. For example, top1 values for the train and valid sets:
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=0, value=80)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=1, value=85)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=2, value=90)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=3, value=91)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=0, value=70)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=1, value=75)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=2, value=80)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=3, value=85)
Reading stored values per set (train/valid/test), per epoch/iteration
loss = checkpoint_handler.get_running_var_with_header(header='train', var_name='loss', iteration=0)
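Since every (header, variable, iteration) triple remains retrievable, a learning curve can be rebuilt after the fact. A minimal sketch of that reconstruction, using a plain dict in place of the handler (for illustration only):

```python
# Rebuild per-set learning curves from stored (header, var_name, iteration)
# values. Plain-dict stand-in for CheckpointHandler, for illustration only.
stored = {
    ('train', 'top1', 0): 80, ('train', 'top1', 1): 85,
    ('train', 'top1', 2): 90, ('train', 'top1', 3): 91,
    ('valid', 'top1', 0): 70, ('valid', 'top1', 1): 75,
    ('valid', 'top1', 2): 80, ('valid', 'top1', 3): 85,
}

def curve(header, var_name, n_iters):
    # Collect the metric in iteration order to recreate the learning curve.
    return [stored[(header, var_name, i)] for i in range(n_iters)]

train_top1 = curve('train', 'top1', 4)   # [80, 85, 90, 91]
valid_top1 = curve('valid', 'top1', 4)   # [70, 75, 80, 85]
```

The same loop works against the real handler by calling get_running_var_with_header for each iteration, e.g. when re-plotting train vs. valid accuracy after resuming a run.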
Saving a checkpoint:
import torchvision.models as models
from torch import optim

checkpoint_handler.store_running_var(var_name='loss', iteration=0, value=1.0)
model = models.resnet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
path2save = '/tmp'
checkpoint_path = checkpoint_handler.generate_checkpoint_path(path2save=path2save)
checkpoint_handler.save_checkpoint(checkpoint_path=checkpoint_path, iteration=25, model=model, optimizer=optimizer)
Loading a checkpoint:
checkpoint_path = '<checkpoint_path>'
checkpoint_handler = checkpoint_handler.load_checkpoint(checkpoint_path)
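Under the hood, a checkpoint of this kind is a serialized snapshot of the model state, optimizer state, and stored metrics (PyTorch itself uses torch.save/torch.load). The round trip can be sketched with pickle on a plain dict; this is a stand-in for illustration, not the library's actual format:

```python
import os
import pickle
import tempfile

# Sketch of the checkpoint round trip: bundle state + metrics, serialize,
# reload. The real library serializes model and optimizer state with
# torch-based saving; plain pickle on dicts is used here only to illustrate.
checkpoint = {
    'iteration': 25,
    'model_state': {'fc.weight': [0.1, 0.2]},           # placeholder state dict
    'optimizer_state': {'lr': 0.001, 'momentum': 0.9},  # placeholder state dict
    'metrics': {('loss', 0): 1.0},
}

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
with open(path, 'wb') as f:
    pickle.dump(checkpoint, f)   # save

with open(path, 'rb') as f:
    restored = pickle.load(f)    # load; training can resume from here
```

Because the optimizer state and metric history travel with the weights, resuming from `restored` continues training (and the learning curve) exactly where it left off, which is the point of checkpointing on spot instances.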