如何将Pythorch数据加载器转换为numpy数组以使用matplotlib显示图像数据?

2024-03-28 14:03:17 发布

您现在位置:Python中文网/ 问答频道 /正文

我是新来的。在我开始接受CNN培训之前,我一直在努力学习如何查看输入图像。我很难将图像转换为可与matplotlib一起使用的形式。在

到目前为止,我已经试过了:

from multiprocessing import freeze_support

import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam

import matplotlib.pyplot as plt
import numpy as np
import PIL

num_classes = 5
batch_size = 100
num_of_workers = 5

DATA_PATH_TRAIN = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\train'
DATA_PATH_TEST = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\test'

trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToPImage(),
    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
    ])

train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    print(npimg)
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))

def main():
    # get some random training images
    dataiter = iter(train_loader)
    images, labels = dataiter.next()

    # show images
    imshow(images)
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

if __name__ == "__main__":
    main()

但是,这会引发错误:

^{pr2}$

我试着打印出数组来得到尺寸,但我不知道这是怎么回事。这很令人困惑。在

我的直接问题是:在使用DataLoader对象中的张量进行训练之前,如何查看输入图像?在


Tags: from图像importimgdatasizebatchtrain
1条回答
网友
1楼 · 发布于 2024-03-28 14:03:17

首先,dataloader输出4维张量-[batch, channel, height, width]。Matplotlib和其他图像处理库通常需要[height, width, channel]。你使用转置是对的,只是方式不对。在

在您的images中会有很多图像,所以首先您需要选择一个(或编写一个for循环来保存所有图像)。这将是简单的images[i],通常我使用i=0。在

然后,你的转置应该把现在[channel, height, width]张量转换成[height, width, channel]张量。为此,请使用np.transpose(image.numpy(), (1, 2, 0)),非常像您的。在

把它们放在一起,你应该

plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))

有时您需要调用.detach()(将这部分从计算图中分离)和{}(将数据从GPU传输到CPU),这将取决于用例

^{pr2}$

相关问题 更多 >