我试着只学习两门课,例如:CIFAR数据集中的猫和狗(作为火车集加载)。 我正在尝试使用以下代码来实现这一点:
def getIndices(d_targets, idx):
lst=[]
for j in idx:
for (i,index) in enumerate(d_targets):
if (index == j):
lst.append(i)
return lst
labels_to_select = [3,5] #cat vs dog
trainset_subset_labels = getIndices(trainset.targets,labels_to_select)
trainset_2 = torch.utils.data.Subset(trainset,trainset_subset_labels)
trainloader = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size,shuffle=True,
num_workers=2)
列车组形状->;(50000,32,32,3)
所需列车组2形状->;(10000,32,32,3)
我得到的列车组2形状->;(50000,32,32,3)
我应该在trainset_2中得到一个更小的数据集,但这并没有发生。知道我做错了什么吗
目前没有回答
相关问题 更多 >
编程相关推荐