Pythorch数据集:将整个数据集转换为NumPy

2024-03-29 15:32:07 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试将Torchvision MNIST训练和测试数据集转换为NumPy数组,但找不到实际执行转换的文档。在

我的目标是获取整个数据集并将其转换为单个NumPy数组,最好不要遍历整个数据集。在

我看过How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?,但它没有解决我的问题。在

所以我的问题是,利用torch.utils.data.DataLoader,如何将数据集(train/test)转换成两个NumPy数组,以便所有的示例都出现?在

注意:我暂时将批处理大小保留为默认值1;对于train,我可以将其设置为60000,对于test,我可以将其设置为10000,但是我不希望使用这种幻数。在

谢谢。在


Tags: 数据文档testnumpy目标datatrainpytorch
2条回答

此任务不需要使用torch.utils.data.DataLoader。在

from torchvision import datasets, transforms

train_set = datasets.MNIST('./data', train=True, download=True)
test_set = datasets.MNIST('./data', train=False, download=True)

train_set_array = train_set.data.numpy()
test_set_array = test_set.data.numpy()

注意,在这种情况下,目标被排除在外。在

如果我理解正确的话,您想将MNIST图像的整个序列数据集(总共60000个图像,每个图像的大小为1x28x28数组,颜色通道为1个)作为一个numpy数组(60000,1,28,28)?在

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transform to normalized Tensors 
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST('./MNIST/', train=True, transform=transform, download=True)
# test_dataset = datasets.MNIST('./MNIST/', train=False, transform=transform, download=True)


train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
# test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

train_dataset_array = next(iter(train_loader))[0].numpy()
# test_dataset_array = next(iter(test_loader))[0].numpy()

结果是:

^{pr2}$

相关问题 更多 >