How to get the file names of images I put into a PyTorch DataLoader

Posted 2024-05-21 05:02:31


I am loading images with PyTorch like this:

inf_data = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
inf_dataloader = DataLoader(inf_data, batch_size=1, shuffle=True, num_workers=2)

Then:

    with torch.no_grad():
        for batch_idx, (img_np, img_tor) in enumerate(inf_dataloader, start=1):

            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

But I also want to get the file name of each image. Can anyone help? Thanks a lot.


Tags: image, target, img, data, size, batch, args, folder
1 answer
User
#1 · Posted 2024-05-21 05:02:31

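The DataLoader itself has no way to give you the file name. But in the Dataset, that is, the InfDataloader class from the question above, you can return the file name together with the tensor: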

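import os

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# pad_resize_image is a helper from the asker's project; it is not shown here.


class InfDataloader(Dataset):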
    """
    Dataloader for Inference.
    """
    def __init__(self, img_folder, target_size=256):
        self.imgs_folder = img_folder

        self.img_paths = []

        img_path = self.imgs_folder + '/'
        img_list = os.listdir(img_path)
        img_list.sort()
        img_list.sort(key=lambda x: int(os.path.splitext(x)[0]))  # sort numerically; names must be integers, e.g. 12.png
        img_nums = len(img_list)
        for i in range(img_nums):
            img_name = img_path + img_list[i]
            self.img_paths.append(img_name)

        # self.img_paths = sorted(glob.glob(self.imgs_folder + '/*'))

        print(self.img_paths)


        self.target_size = target_size
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __getitem__(self, idx):
        """
        __getitem__ for inference
        :param idx: Index of the image
        :return: img_np is a numpy RGB-image of shape H x W x C with pixel values in range 0-255.
        img_tor is a torch tensor, RGB, C x H x W in shape and normalized.
        name is the path of the image file.
        """
        img = cv2.imread(self.img_paths[idx])
        name = self.img_paths[idx]

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Pad images to target size
        img_np = pad_resize_image(img, None, self.target_size)
        img_tor = img_np.astype(np.float32)
        img_tor = img_tor / 255.0
        img_tor = np.transpose(img_tor, axes=(2, 0, 1))
        img_tor = torch.from_numpy(img_tor).float()
        img_tor = self.normalize(img_tor)

        return img_np, img_tor, name

    def __len__(self):
        # DataLoader's default sampler needs the dataset length
        return len(self.img_paths)

Here I added the line name = self.img_paths[idx] and returned name alongside the image.
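The numeric sort in `__init__` only works when every file name is a bare integer plus an extension. If the folder may contain other names, a more tolerant ordering could look like the sketch below. `natural_key` and the sample names are hypothetical and are not part of the original code.

```python
import os
import re


def natural_key(file_name):
    """Split the stem into text and integer chunks so '10.png' sorts after '2.png'."""
    stem = os.path.splitext(file_name)[0]
    # re.split with a capturing group alternates text chunks and digit chunks
    return [int(part) if part.isdecimal() else part.lower()
            for part in re.split(r"(\d+)", stem)]


names = ["10.png", "2.jpeg", "1.png", "img_3.png"]  # made-up file names
sorted_names = sorted(names, key=natural_key)
print(sorted_names)
```

Passing `key=natural_key` to `img_list.sort` would replace the `int(...)` key without changing the rest of the class.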

The inference loop then unpacks the extra value:

    with torch.no_grad():
        for batch_idx, (img_np, img_tor, name) in enumerate(inf_dataloader, start=1):
            # name is a sequence of paths, one per image in the batch (one entry here, since batch_size=1)
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

This way I can get the name of each image.
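To see how the returned names arrive in the loop, here is a minimal self-contained sketch. It assumes the DataLoader's default collate function, and it uses a hypothetical `ToyImageDataset` with zero tensors and made-up paths in place of the real image loading. It also returns only the base name via `os.path.basename`, which is an optional variation on returning the full path.

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset


class ToyImageDataset(Dataset):
    """Hypothetical stand-in for InfDataloader: placeholder tensors, no real images."""

    def __init__(self, img_paths):
        self.img_paths = list(img_paths)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        path = self.img_paths[idx]
        img_tor = torch.zeros(3, 8, 8)  # placeholder for the decoded, normalized image
        return img_tor, os.path.basename(path)  # file name without the directory


paths = ["imgs/1.png", "imgs/2.png", "imgs/3.png"]  # made-up paths, never opened
loader = DataLoader(ToyImageDataset(paths), batch_size=2, shuffle=False, num_workers=0)

batches = []
for img_tor, names in loader:
    # names holds one string per sample, in the same order as img_tor's first dimension
    batches.append((tuple(img_tor.shape), list(names)))
    for name in names:
        print(name)
```

With `batch_size=1`, as in the question, the name of the current image is `names[0]`.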
