PyTorch model state saving and loading
torch-model-state
Save and load PyTorch model state.
Installation
Requires Python 3.6+.
pip install torch-model-state
Usage
Python:
import box
import torch_model_state
from torch.optim import SGD

config = {
    'type': 'MobileNetV2'  # requires torch-basic-models to be installed
}
model = box.factory(config=config, tag='model')
optimizer = SGD(model.parameters(), lr=0.1)

# bundle model, config and optimizer state, then write it to disk
state = torch_model_state.to_state(model=model, config=config, optimizers=[optimizer])
torch_model_state.save_state_file(state=state, file_path='checkpoint.sf')

# read the state file back and restore it into the existing model and optimizer
state = torch_model_state.load_state_file(file_path='checkpoint.sf', device='cpu')
torch_model_state.from_state(state, model, [optimizer], device='cpu')
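For readers who want to see roughly what such a checkpoint bundles, here is a minimal plain-PyTorch sketch of a comparable save/load round trip. It assumes only `torch` and uses a small `nn.Linear` in place of MobileNetV2. The helper names `make_state` and `restore_state` and the `model`/`optimizers` keys are hypothetical; this is not the actual `.sf` format or implementation of torch-model-state.

```python
import datetime
import io

import torch
from torch import nn
from torch.optim import SGD


def make_state(model, config, optimizers, info=None):
    """Bundle config, metadata and weights into one plain dict (hypothetical layout)."""
    return {
        'config': config,
        'info': info,
        # str(datetime) looks like '2019-04-27 22:42:55.345000'
        'timestamp': str(datetime.datetime.now()),
        'model': model.state_dict(),
        'optimizers': [opt.state_dict() for opt in optimizers],
    }


def restore_state(state, model, optimizers):
    """Copy saved tensors back into already-constructed model/optimizer objects."""
    model.load_state_dict(state['model'])
    for opt, opt_state in zip(optimizers, state['optimizers']):
        opt.load_state_dict(opt_state)


torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so SGD has momentum buffers worth saving.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

# Save to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.save(make_state(model, {'type': 'Linear'}, [optimizer]), buffer)
buffer.seek(0)

# Load onto the CPU and restore into freshly built objects.
loaded = torch.load(buffer, map_location='cpu')
new_model = nn.Linear(4, 2)
new_optimizer = SGD(new_model.parameters(), lr=0.1, momentum=0.9)
restore_state(loaded, new_model, [new_optimizer])
```

Restoring state dicts into pre-built objects, rather than pickling whole modules, keeps the checkpoint independent of the Python class definitions; the stored config is what would let a loader rebuild the model first.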
Load a model directly from a state file (.sf):
import torch_model_state

model = torch_model_state.load_model_from_state(file_path='checkpoint.sf', device='cpu')
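Loading a model from the file alone only works if the file records enough to rebuild the architecture; presumably that is the role of the stored config (the command-line viewer prints a `config` entry with `"type": "MobileNetV2"`). A minimal sketch of the idea, assuming a hypothetical `MODEL_REGISTRY` that maps a config `type` to a constructor in place of whatever factory the library really uses, and a hypothetical checkpoint dict with `config` and `model` keys:

```python
import io

import torch
from torch import nn

# Hypothetical registry: config['type'] -> function that builds the model.
MODEL_REGISTRY = {
    'TinyMLP': lambda: nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)),
}


def load_model_from_checkpoint(file_obj, device='cpu'):
    """Rebuild the architecture from the stored config, then load its weights."""
    state = torch.load(file_obj, map_location=device)
    model = MODEL_REGISTRY[state['config']['type']]()
    model.load_state_dict(state['model'])
    return model.to(device).eval()


# Create and save a checkpoint that carries its own config.
torch.manual_seed(0)
original = MODEL_REGISTRY['TinyMLP']()
buffer = io.BytesIO()
torch.save({'config': {'type': 'TinyMLP'}, 'model': original.state_dict()}, buffer)
buffer.seek(0)

# The loader needs only the file: no model object is passed in.
restored = load_model_from_checkpoint(buffer)
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(original(x), restored(x))
```

The trade-off of this design is that every model type named in a checkpoint must be registered (imported) before loading, which is likely why the CLI offers an `--extra_import` option.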
Command line:
# show help
torch-model-state -h
#> usage: torch-model-state [-h] [--load_model] [--extra_import EXTRA_IMPORT]
#>                          [--device DEVICE]
#>                          state_file
#>
#> Viewer of PyTorch State File [.sf]
#>
#> positional arguments:
#>   state_file            path of PyTorch state file
#>
#> optional arguments:
#>   -h, --help            show this help message and exit
#>   --load_model          load model and show
#>   --extra_import EXTRA_IMPORT
#>                         import extra models
#>   --device DEVICE       load device, cpu in default

# view basic info of state file
torch-model-state checkpoint.sf
#> {
#>     "config": {
#>         "type": "MobileNetV2"
#>     },
#>     "info": null,
#>     "timestamp": "2019-04-27 22:42:55.345000"
#> }

# view info and load model
torch-model-state checkpoint.sf --load_model
#> {
#>     "config": {
#>         "type": "MobileNetV2"
#>     },
#>     "info": null,
#>     "timestamp": "2019-04-27 22:42:55.345000"
#> }
#> MobileNetV2(
#>   (blocks): Sequential(
#>     (0): Sequential(
#>       (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#>       (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#>       (2): InplaceReLU6(inplace)
#>     )
#> ...

# export to ONNX
torch-model-state checkpoint.sf --export_onnx checkpoint.onnx
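The basic info printed by the viewer looks like JSON-serialized metadata. As a rough illustration only, not the tool's real implementation, this sketch assumes a hypothetical checkpoint that is a plain dict with `config`, `info` and `timestamp` keys plus a `model` entry holding weights, and prints just the metadata while skipping the tensors:

```python
import io
import json

import torch
from torch import nn

METADATA_KEYS = ('config', 'info', 'timestamp')


def describe_checkpoint(file_obj):
    """Return the checkpoint's metadata as indented JSON, skipping weight tensors."""
    state = torch.load(file_obj, map_location='cpu')
    meta = {key: state.get(key) for key in METADATA_KEYS}
    return json.dumps(meta, indent=4)


# Build a sample checkpoint in memory with hypothetical contents.
buffer = io.BytesIO()
torch.save({
    'config': {'type': 'MobileNetV2'},
    'info': None,
    'timestamp': '2019-04-27 22:42:55.345000',
    'model': nn.Linear(4, 2).state_dict(),
}, buffer)
buffer.seek(0)

text = describe_checkpoint(buffer)
print(text)
# {
#     "config": {
#         "type": "MobileNetV2"
#     },
#     "info": null,
#     "timestamp": "2019-04-27 22:42:55.345000"
# }
```

Because only the metadata keys are serialized, the summary stays short and readable regardless of model size.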