易于在pytorch中训练神经网络的高级库。
pytorch-argus的Python项目详细描述
argus
ARGUS是PyTorch神经网络训练的易用柔性库。
安装
来自PIP:
pip install pytorch-argus
来源:
git clone https://github.com/lRomul/argus
cd argus
python setup.py install
示例
简单的图像分类示例:
importtorchfromtorchimportnnimporttorch.nn.functionalasFfrommnist_utilsimportget_data_loadersfromargusimportModel,load_modelfromargus.callbacksimportMonitorCheckpoint,EarlyStopping,ReduceLROnPlateauclassNet(nn.Module):def__init__(self,n_classes,p_dropout=0.5):super().__init__()self.conv1=nn.Conv2d(1,10,kernel_size=5)self.conv2=nn.Conv2d(10,20,kernel_size=5)self.conv2_drop=nn.Dropout2d(p=p_dropout)self.fc1=nn.Linear(320,50)self.fc2=nn.Linear(50,n_classes)defforward(self,x):x=F.relu(F.max_pool2d(self.conv1(x),2))x=F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))x=x.view(-1,320)x=F.relu(self.fc1(x))x=F.dropout(x,training=self.training)x=self.fc2(x)returnxclassMnistModel(Model):nn_module=Netoptimizer=torch.optim.SGDloss=torch.nn.CrossEntropyLossif__name__=="__main__":train_loader,val_loader=get_data_loaders()params={'nn_module':{'n_classes':10,'p_dropout':0.1},'optimizer':{'lr':0.01},'device':'cpu'}model=MnistModel(params)callbacks=[MonitorCheckpoint(dir_path='mnist',monitor='val_accuracy',max_saves=3),EarlyStopping(monitor='val_accuracy',patience=9),ReduceLROnPlateau(monitor='val_accuracy',factor=0.5,patience=3)]model.fit(train_loader,val_loader=val_loader,max_epochs=50,metrics=['accuracy'],callbacks=callbacks,metrics_on_train=True)delmodelmodel=load_model('mnist/model-last.pth')
与pytorch-cnn-finetune中的make_model
一起使用argus。
fromcnn_finetuneimportmake_modelfromargusimportModelclassCnnFinetune(Model):nn_module=make_modelparams={'nn_module':{'model_name':'resnet18','num_classes':10,'pretrained':False,'input_size':(256,256)},'optimizer':('Adam',{'lr':0.01}),'loss':'CrossEntropyLoss','device':'cuda'}model=CnnFinetune(params)
您可以找到其他示例here。
Kaggle解决方案
- 2019年免费音频标签的第一名解决方案(MEL频谱图,与APEX混合精确训练)
https://github.com/lRomul/argus-freesound - TGS盐鉴定挑战第14名解决方案(分段,平均老师)
https://github.com/lRomul/argus-tgs-salt - 第50位快速解决方案,绘制!涂鸦识别挑战(梯度积累,50米图像训练)
https://github.com/lRomul/argus-quick-draw - Kaggle空客船舶检测挑战第66位解决方案(实例分割)
https://github.com/OniroAI/Universal-segmentation-baseline-Kaggle-Airbus-Ship-Detection - 座头鲸识别的解决方案(度量学习:弧面,中心损失)
https://github.com/lRomul/argus-humpback-whale - vsb电源线故障检测(1d conv)解决方案
https://github.com/lRomul/argus-vsb-power