Pytorch中元学习的数据加载器
torchmeta的Python项目详细描述
托奇梅塔
在PyTorch中为少数镜头学习和元学习提供的扩展和数据加载程序的集合。这个包包含流行的元学习基准,完全兼容^{
示例
下面这个最小的示例演示如何使用torchmeta
为5-shot 5-way omniglot数据集创建数据加载器。数据加载器加载一批随机生成的任务。有关更多示例,请检查examples文件夹。
fromtorchmeta.datasetsimportOmniglotfromtorchmeta.transformsimportCategorical,ClassSplitterfromtorchvision.transformsimportResize,ToTensor,Composefromtorchmeta.utils.dataimportBatchMetaDataLoaderdataset=Omniglot('data',num_classes_per_task=5,transform=Compose([Resize(28),ToTensor()]),target_transform=Categorical(num_classes=5),meta_train=True,download=True)dataset=ClassSplitter(dataset,num_train_per_class=5,num_test_per_class=15)dataloader=BatchMetaDataLoader(dataset,batch_size=16,num_workers=4)forbatchindataloader:train_inputs,train_targets=batch['train']print('Train inputs shape: {0}'.format(train_inputs.shape))print('Train targets shape: {0}'.format(train_targets.shape))# Train inputs shape: torch.Size([16, 25, 1, 28, 28])# Train targets shape: torch.Size([16, 25])test_inputs,test_targets=batch['test']print('Test inputs shape: {0}'.format(test_inputs.shape))print('Test targets shape: {0}'.format(test_targets.shape))# Test inputs shape: torch.Size([16, 75, 1, 28, 28])# Test targets shape: torch.Size([16, 75])