Pytorch研究者元学习框架
learn2learn的Python项目详细描述
learn2learn是一个用于元学习实现的pytorch库。 它是在first PyTorch Hackathon时期发展起来的。编辑:L2L很幸运赢得了Hackathon!
Notelearn2learn正在积极开发中,许多东西正在崩溃
安装
pip install learn2learn
API演示
importlearn2learnasl2lmnist=torchvision.datasets.MNIST(root="/tmp/mnist",train=True)task_generator=l2l.data.TaskGenerator(mnist,ways=3)model=Net()maml=l2l.MAML(model,lr=1e-3,first_order=False)opt=optim.Adam(maml.parameters(),lr=4e-3)foriterationinrange(num_iterations):learner=maml.new()# Creates a clone of modeltask=task_generator.sample(shots=1)# Fast adaptforstepinrange(adaptation_steps):error=compute_loss(task)learner.adapt(error)# Compute validation lossvalid_task=task_generator.sample(shots=1,classes_to_sample=task.sampled_classes)valid_error=compute_loss(valid_task)# Take the meta-learning stepopt.zero_grad()valid_error.backward()opt.step()
变更日志
以下更改日志主要用于hackathon期间。
2019年8月12日
- 基本实现了maml,fomaml,meta-sgd。
- 分类任务的任务生成器代码。
- RL的环境。
- MAML-A2C和MAML-PPO的小规模示例