工具箱使使用Pythorch更容易。
torchtoolbox的Python项目详细描述
Pythorch工具箱
这是Pythorch的工具箱项目。旨在使编写pytorch代码更容易、可读和简洁。
你也可以把它看作是火炬手的辅助工具。它将包含您最常用的工具。安装
一个简单的安装方法是使用pip:
pip install torchtoolbox
如果要安装夜间版本(建议现在安装):
pip install -U git+https://github.com/deeplearningforfun/torch-toolbox.git@master
用法
工具箱主要有两部分:
- 其他工具使您更容易使用pytorch。 有些时尚作品不存在于火炬核心中。
工具
1.显示模型参数和触发器。
importtorchfromtorchtoolbox.toolsimportsummaryfromtorchvision.models.mobilenetimportmobilenet_v2model=mobilenet_v2()summary(model,torch.rand((1,3,224,224)))
下面是一些简短的输出。
Layer (type) Output Shape Params FLOPs(M+A) #
================================================================================
Conv2d-1 [1, 64, 112, 112] 9408 235225088
BatchNorm2d-2 [1, 64, 112, 112] 256 1605632
ReLU-3 [1, 64, 112, 112] 0 0
MaxPool2d-4 [1, 64, 56, 56] 0 0
... ... ... ...
Linear-158 [1, 1000] 1281000 2560000
MobileNetV2-159 [1, 1000] 0 0
================================================================================
Total parameters: 3,538,984 3.5M
Trainable parameters: 3,504,872
Non-trainable parameters: 34,112
Total flops(M) : 305,252,872 305.3M
Total flops(M+A): 610,505,744 610.5M
--------------------------------------------------------------------------------
Parameters size (MB): 13.50
2.公制集合
当我们训练一个模型时,我们通常需要计算一些指标,如精度(top1 acc)、损失等。 现在工具箱支持如下:
- 准确度:TOP-1 acc.
- topkaccuracy:topk-acc.
- numericalcost:这是一个支持
mean
、max
、min
计算类型的数字度量集合。
fromtorchtoolboximportmetric# define firsttop1_acc=metric.Accuracy(name='Top1 Accuracy')top5_acc=metric.TopKAccuracy(top=5,name='Top5 Accuracy')loss_record=metric.NumericalCost(name='Loss')# reset before usingtop1_acc.reset()top5_acc.reset()loss_record.reset()...model.eval()fordata,labelsinval_data:data=data.to(device,non_blocking=True)labels=labels.to(device,non_blocking=True)outputs=model(data)losses=Loss(outputs,labels)# update/recordtop1_acc.step(outputs,labels)top5_acc.step(outputs,labels)loss_record.step(losses)test_msg='Test Epoch {}: {}:{:.5}, {}:{:.5}, {}:{:.5}\n'.format(epoch,top1_acc.name,top1_acc.get(),top5_acc.name,top5_acc.get(),loss_record.name,loss_record.get())print(test_msg)
然后您可能会得到这样的输出
Test Epoch 101: Top1 Accuracy:0.7332, Top5 Accuracy:0.91514, Loss:1.0605
3.模型初始值设定项
现在工具箱支持XavierInitializer
和KaimingInitializer
。
fromtorchtoolbox.nn.initimportKaimingInitializermodel=XXXKaimingInitializer(model)
时装作品
1.labelsmoothingloss
fromtorchtoolbox.nnimportLabelSmoothingLoss# The num classes of your task should be defined.classes=10# LossLoss=LabelSmoothingLoss(classes,smoothing=0.1)...fori,(data,labels)inenumerate(train_data):data=data.to(device,non_blocking=True)labels=labels.to(device,non_blocking=True)optimizer.zero_grad()outputs=model(data)# just use as usual.loss=Loss(outputs,labels)loss.backward()optimizer.step()
2.cosinewarmupler
带有预热期的余弦lr调度器,有助于提高分类模型的acc。
fromtorchtoolbox.optimizerimportCosineWarmupLroptimizer=optim.SGD(...)# define scheduler# `batches_pre_epoch` means how many batches(times update/step the model) within one epoch.# `warmup_epochs` means increase lr how many epochs to `base_lr`.# you can find more details in file.lr_scheduler=CosineWarmupLr(optimizer,batches_pre_epoch,epochs,base_lr=lr,warmup_epochs=warmup_epochs)...fori,(data,labels)inenumerate(train_data):...optimizer.step()# remember to step/update status here.lr_scheduler.step()...
3.开关标准2D/3D
fromtorchtoolbox.nnimportSwitchNorm2d,SwitchNorm3d
像batchnorm2d/3d那样使用它。 更多详情请参考原始文件 Differentiable Learning-to-Normalize via Switchable NormalizationOpenSourse
4.swish激活
fromtorchtoolbox.nnimportSwish
就像雷卢一样使用它。 更多详情请参考原始文件 SEARCHING FOR ACTIVATION FUNCTIONS
5。前瞻优化器
包装优化器似乎比adam好。 Lookahead Optimizer: k steps forward, 1 step back
fromtorchtoolbox.optimizerimportLookaheadfromtorchimportoptimoptimizer=optim.Adam(...)optimizer=Lookahead(optimizer)
5。混合训练
训练分类模型的混合方法。 mixup: BEYOND EMPIRICAL RISK MINIMIZATION
fromtorchtoolbox.toolsimportmixup_data,mixup_criterion# set beta distributed parm, 0.2 is recommend.alpha=0.2fori,(data,labels)inenumerate(train_data):data=data.to(device,non_blocking=True)labels=labels.to(device,non_blocking=True)data,labels_a,labels_b,lam=mixup_data(data,labels,alpha)optimizer.zero_grad()outputs=model(data)loss=mixup_criterion(Loss,outputs,labels_a,labels_b,lam)loss.backward()optimizer.step()
6.切口
一种图像变换方法。 Improved Regularization of Convolutional Neural Networks with Cutout
fromtorchvisionimporttransformsfromtorchtoolbox.transformimportCutout_train_transform=transforms.Compose([transforms.RandomResizedCrop(224),Cutout(),transforms.RandomHorizontalFlip(),transforms.ColorJitter(0.4,0.4,0.4),transforms.ToTensor(),normalize,])
7.无衰减偏差
如果你训练一个大批量的模型,例如64K,你可能需要这个, Highly Scalable Deep Learning Training System with Mixed-Precision: Training ImageNet in Four Minutes
fromtorchtoolbox.toolsimportsplit_weightsfromtorchimportoptimmodel=XXXparameters=split_weights(model)optimizer=optim.SGD(parameters,...)