我是新来的。在我开始接受CNN培训之前,我一直在努力学习如何查看输入图像。我很难将图像转换为可与matplotlib一起使用的形式。在
到目前为止,我已经试过了:
from multiprocessing import freeze_support
import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
import PIL
num_classes = 5
batch_size = 100
num_of_workers = 5
DATA_PATH_TRAIN = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\train'
DATA_PATH_TEST = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\test'
trans = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(32),
transforms.CenterCrop(32),
transforms.ToPImage(),
transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
])
train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
print(npimg)
plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
def main():
# get some random training images
dataiter = iter(train_loader)
images, labels = dataiter.next()
# show images
imshow(images)
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
if __name__ == "__main__":
main()
但是,这会引发错误:
^{pr2}$我试着打印出数组来得到尺寸,但我不知道这是怎么回事。这很令人困惑。在
我的直接问题是:在使用DataLoader对象中的张量进行训练之前,如何查看输入图像?在
首先,
dataloader
输出4维张量-[batch, channel, height, width]
。Matplotlib和其他图像处理库通常需要[height, width, channel]
。你使用转置是对的,只是方式不对。在在您的
images
中会有很多图像,所以首先您需要选择一个(或编写一个for循环来保存所有图像)。这将是简单的images[i]
,通常我使用i=0
。在然后,你的转置应该把现在
[channel, height, width]
张量转换成[height, width, channel]
张量。为此,请使用np.transpose(image.numpy(), (1, 2, 0))
,非常像您的。在把它们放在一起,你应该
有时您需要调用}(将数据从GPU传输到CPU),这将取决于用例
^{pr2}$.detach()
(将这部分从计算图中分离)和{相关问题 更多 >
编程相关推荐