swats算法的pytorch实现。
pytorch-swats的Python项目详细描述
从ADAM切换到SGD
Wilson et al. (2018)表明,“自适应方法发现的解比sgd更差(通常明显更差),即使这些解具有更好的训练性能。这些结果表明,实践者应该重新考虑使用自适应方法来训练神经网络。“
“swats来自2018年iclr的高分论文Keskar & Socher (2017),该方法建议自动从adam切换到sgd,以获得更好的泛化性能。算法本身的思想非常简单。它使用adam,尽管最小的调整效果很好,但是在学习到某个阶段之后,它被sgd接管。“
用法
直接从这个git存储库使用pip或使用以下命令之一从pypi安装包是很简单的。
pip install git+https://github.com/Mrpatekful/swats
pip install pytorch-swats
安装后swats可以用作任何其他torch.optim.Optimizer
。下面的代码片段是如何使用算法的简单概述。
importswatsoptimizer=swats.SWATS(model.parameters())data_loader=torch.utils.data.DataLoader(...)forepochinrange(10):forinputs,targetsindata_loader:# deleting the stored grad valuesoptimizer.zero_grad()outputs=model(inputs)loss=loss_fn(outputs,targets)loss.backward()# performing parameter updateoptimizer.step()