pytorch ai任务中用于可视化和存储管理的python包。
torchtracer的Python项目详细描述
火炬手
torchtracer
是pytorch ai任务中可视化和存储管理的工具包。
开始
需要Pythorch
这个工具是为pytorch ai任务开发的。因此,当然需要火把。
安装
您可以使用pip
来安装torchtracer
。
pip install torchtracer
如何使用?
导入torchtracer
fromtorchtracerimportTracer
创建Tracer
的实例
假设根目录是./checkpoints
,当前任务id是lmmnb
。
避免弄乱工作目录,您应该手动创建根目录。
tracer=Tracer('checkpoints').attach('lmmnb')
此步骤将创建一个目录checkpoints
,其中是当前ai任务的目录lmmnb
。
另外,您可以在没有任务id的情况下调用.attach()
。datetime将用作任务id。
tracer=Tracer('checkpoints').attach()
保存配置
原始配置应该是这样的dict
:
# `net` is a defined nn.Moduleargs={'epoch_n':120,'batch_size':10,'criterion':nn.MSELoss(),'optimizer':torch.optim.RMSprop(net.parameters(),lr=1e-3)}
配置dict应该用torchtracer.data.Config
cfg=Config(args)tracer.store(cfg)
此步骤将在./checkpoints/lmmnb/
中创建config.json
,其中包含如下json信息:
{"epoch_n":120,"batch_size":10,"criterion":"MSELoss","optimizer":{"lr":0.001,"momentum":0,"alpha":0.99,"eps":1e-08,"centered":false,"weight_decay":0,"name":"RMSprop"}}
记录
在训练迭代期间,您可以使用Tracer.log(msg, file)
打印任何想要的信息。
如果未指定file
,则将msg
输出到./checkpoints/lmmnb/log
。否则,它将是./checkpoints/lmmnb/something.log
。
tracer.log(msg='Epoch #{:03d}\ttrain_loss: {:.4f}\tvalid_loss: {:.4f}'.format(epoch,train_loss,valid_loss),file='losses')
此步骤将在./checkpoints/lmmnb/
中创建日志文件losses.log
,其中包含如下日志:
Epoch #001 train_loss: 18.6356 valid_loss: 21.3882 Epoch #002 train_loss: 19.1731 valid_loss: 17.8482 Epoch #003 train_loss: 19.6756 valid_loss: 19.1418 Epoch #004 train_loss: 20.0638 valid_loss: 18.3875 Epoch #005 train_loss: 18.4679 valid_loss: 19.6304 ...
保存模型
模型对象应该用torchtracer.data.Model
如果未指定file
,它将生成模型文件model.txt
。否则,它将是somename.txt
tracer.store(Model(model),file='somename')
此步骤将创建两个文件:
- 说明:
somename.txt
Sequential Sequential( (0): Linear(in_features=1, out_features=6, bias=True) (1): ReLU() (2): Linear(in_features=6, out_features=12, bias=True) (3): ReLU() (4): Linear(in_features=12, out_features=12, bias=True) (5): ReLU() (6): Linear(in_features=12, out_features=1, bias=True) )
- 参数:
somename.pth
保存matplotlib图像
使用tracer.store(figure, file)
将matplotlib图形保存在images/
# assume that `train_losses` and `valid_losses` are lists of losses. # create figure manually.plt.plot(train_losses,label='train loss',c='b')plt.plot(valid_losses,label='valid loss',c='r')plt.title('Demo Learning on SQRT')plt.legend()# save figure. remember to call `plt.gcf()`tracer.store(plt.gcf(),'losses.png')
此步骤将保存表示损耗曲线的png文件losses.png
。
各时期的进度条
使用tracer.epoch_bar_init(total)
初始化进度条。
tracer.epoch_bar_init(epoch_n)
使用tracer.epoch_bar.update(n=1, **params)
更新进度条的后缀。
tracer.epoch_bar.update(train_loss=train_loss,valid_loss=train_loss)
(THIS IS A DEMO) Tracer start at /home/oidiotlin/projects/torchtracer/checkpoints Tracer attached with task: rabbit Epoch: 100%|█████████| 120/120 [00:02<00:00, 41.75it/s, train_loss=0.417, valid_loss=0.417]
不要忘记调用tracer.epoch_bar.close()
来完成该栏。
贡献
如果您喜欢此项目,欢迎使用“拉取请求”和“创建问题”。