A toolkit for training, tracking, and saving PyTorch models
Torch-Scope
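A toolkit for training PyTorch models, with the following features:
- tracking environments, dependencies, implementations, and checkpoints;
- providing a logger wrapper with two handlers (to std and file);
- supporting automatic device selection;
- providing a TensorBoard wrapper;
- providing a spreadsheet writer that automatically summarizes notes and results.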
We are in an early-release beta. Expect some adventures and rough edges.
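To make the "logger wrapper with two handlers" feature concrete, here is a minimal sketch of the same idea built only on the standard `logging` module. It is not Torch-Scope's implementation; the function name `build_logger` and its parameters are hypothetical and chosen for illustration.

```python
import logging
import sys


def build_logger(name, log_path, level=logging.INFO):
    """Return a logger that writes each record to stdout and to log_path.

    Hypothetical helper for illustration; not part of Torch-Scope.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # avoid duplicate lines via the root logger

    # Attach the two handlers only once, even if build_logger is called again.
    if not logger.handlers:
        formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
        handlers = (
            logging.StreamHandler(sys.stdout),                # "std" handler
            logging.FileHandler(log_path, encoding="utf-8"),  # "file" handler
        )
        for handler in handlers:
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    return logger
```

Calling `build_logger("train", "log.txt").info("started")` would print the record to the console and append it to `log.txt`.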
Installation
To install via PyPI:
pip install torch-scope
To build from source:
pip install git+https://github.com/LiyuanLucasLiu/Torch-Scope
or
git clone https://github.com/LiyuanLucasLiu/Torch-Scope.git
cd Torch-Scope
python setup.py install
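Usage
An example is provided below; please read the documentation for a detailed API explanation.
- set up git on the server and add all source files to git
- use TensorBoard to track the model stats (tensorboard --logdir path/log/ --port <port>)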
import argparse
import logging
import os

import torch
from torch_scope import wrapper
...

logger = logging.getLogger(__name__)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str, ...)
    parser.add_argument('--name', type=str, ...)
    parser.add_argument('--gpu', type=str, ...)
    ...
    args = parser.parse_args()

    pw = wrapper(os.path.join(args.checkpoint_path, args.name), name = args.log_dir, enable_git_track = False)
    # Or, if the current folder is bound to a git repository, you can turn on git tracking as below
    # pw = wrapper(os.path.join(args.checkpoint_path, args.name), name = args.log_dir, enable_git_track = True)
    # If credential_path is set properly and you want to use the spreadsheet writer, turn on sheet tracking as below
    # pw = wrapper(os.path.join(args.checkpoint_path, args.name), name = args.log_dir, \
    #              enable_git_track=args.git_tracking, sheet_track_name=args.spreadsheet_name, \
    #              credential_path="/data/work/jingbo/ll2/Torch-Scope/torch-scope-8acf12bee10f.json")

    # A negative GPU index selects the CPU
    gpu_index = pw.auto_device() if 'auto' == args.gpu else int(args.gpu)
    device = torch.device("cuda:" + str(gpu_index) if gpu_index >= 0 else "cpu")

    pw.save_configue(args) # dump the config to config.json

    # If the spreadsheet writer is enabled, you can add a description of the current model
    # pw.add_description(args.description)

    logger.info(str(args)) # written to std & file if the level is 'info' or lower

    ...

    batch_index = 0
    tot_loss = 0
    best_score = float('-inf')

    for index in range(epoch):
        ...
        for instance in ... :
            loss = ...
            tot_loss += loss.detach()
            loss.backward()

            if batch_index % ... == 0:
                # Periodically log the averaged loss and model statistics
                pw.add_loss_vs_batch({'loss': tot_loss / ..., ...}, batch_index, False)
                pw.add_model_parameter_stats(model, batch_index, save=True)
                optimizer.step()
                pw.add_model_update_stats(model, batch_index)
                tot_loss = 0
            else:
                optimizer.step()

            batch_index += 1

        dev_score = ...
        pw.add_loss_vs_batch({'dev_score': dev_score, ...}, index, True)

        # Save a checkpoint every epoch and mark it as best when the dev score improves
        if dev_score > best_score:
            pw.save_checkpoint(model, optimizer, is_best = True)
            best_score = dev_score
        else:
            pw.save_checkpoint(model, optimizer, is_best = False)
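The example above comments `pw.save_configue(args)` as dumping the config to config.json. As a rough illustration of what such a dump might involve, the sketch below serializes an `argparse.Namespace` with the standard library only. The helper name `dump_config` and its parameters are assumptions made for this sketch, and Torch-Scope's own behavior may differ.

```python
import argparse
import json
import os


def dump_config(args: argparse.Namespace, folder: str, filename: str = "config.json") -> str:
    """Write the parsed arguments to folder/filename as JSON and return the path.

    Hypothetical helper for illustration; not part of Torch-Scope.
    """
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, filename)
    with open(path, "w", encoding="utf-8") as f:
        # vars() turns the Namespace into a plain dict of option names to values.
        json.dump(vars(args), f, indent=2, sort_keys=True)
    return path
```

Storing the configuration next to the checkpoints makes it possible to reload a run later with `json.load` and see exactly which options produced it.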
Advanced Usage
Auto Device
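This document does not say how `pw.auto_device()` chooses a device. One common heuristic, shown here purely as an assumption, is to pick the GPU with the most free memory and fall back to the CPU (index -1, matching the `gpu_index >= 0` check in the usage example) when no GPU is visible. The helper names `query_free_memory` and `pick_device` are hypothetical.

```python
import subprocess


def query_free_memory():
    """Return free memory in MiB for each GPU reported by nvidia-smi, or [] if unavailable."""
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True, timeout=10,
        ).stdout
        return [int(line) for line in out.splitlines() if line.strip()]
    except (OSError, subprocess.SubprocessError, ValueError):
        # nvidia-smi is missing, failed, timed out, or printed something unexpected.
        return []


def pick_device(free_memory):
    """Return the index of the GPU with the most free memory, or -1 for CPU."""
    if not free_memory:
        return -1
    # On ties, max() keeps the first (lowest) index.
    return max(range(len(free_memory)), key=free_memory.__getitem__)
```

With these helpers, `gpu_index = pick_device(query_free_memory())` could stand in for the `'auto'` branch of the usage example on a machine where `nvidia-smi` is installed.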
Git Tracking
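Tracking an implementation with git usually comes down to recording which commit a run was launched from. The sketch below shows one way to read that commit with the `git` command line; it assumes `git` is installed and is not a description of what `enable_git_track` does internally. The helper name `current_commit` is hypothetical.

```python
import subprocess


def current_commit(repo_dir="."):
    """Return the HEAD commit hash of repo_dir, or None if it cannot be determined.

    Hypothetical helper for illustration; not part of Torch-Scope.
    """
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=repo_dir, capture_output=True, text=True, check=True, timeout=10,
        )
    except (OSError, subprocess.SubprocessError):
        # git is missing, repo_dir does not exist, or it is not a git repository.
        return None
    return result.stdout.strip() or None
```

Writing the returned hash next to the saved config and checkpoints lets a result be traced back to the exact source revision, which is why the usage notes ask for all source files to be added to git first.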
Spreadsheet Logging
Share the spreadsheet with the following account: torch-scope@torch-scope.iam.gserviceaccount.com, and access the sheet by its name.