获取属于特定类的数据子集

2024-04-23 20:38:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我试着只学习两门课,例如:CIFAR数据集中的猫和狗(作为火车集加载)。 我正在尝试使用以下代码来实现这一点:

def getIndices(d_targets, idx):
    lst=[]
    for j in idx:
        for (i,index) in enumerate(d_targets):
            if (index == j):
                lst.append(i)
    return lst    
labels_to_select = [3,5] #cat vs dog
trainset_subset_labels = getIndices(trainset.targets,labels_to_select)
trainset_2 = torch.utils.data.Subset(trainset,trainset_subset_labels)
trainloader = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size,shuffle=True, 
num_workers=2)

列车组形状->;(50000,32,32,3)

所需列车组2形状->;(10000,32,32,3)

我得到的列车组2形状->;(50000,32,32,3)

我应该在trainset_2中得到一个更小的数据集,但这并没有发生。知道我做错了什么吗