迭代torch数据集未加载多个目标

0 投票
1 回答
33 浏览
提问于 2025-04-12 19:54

我正在尝试从文件中加载一个数据集,并在其上训练一个人工智能模型。出于某种原因,当我在主程序中使用 for iamges, targets in dataloader 时,它加载的目标数据是这样的:

{
    'image_id':[all image ids],
    'keypoints':[all keypoint lists],
    'labels':[all label lists],
    'boxes':[all bboxes]
}

而不是我期望的这样:

[
    {
        'image_id':image_id of first sample,
        'keypoints':[list of keypoints of first sample],
        'labels':[list of labels of first sample],
        'boxes':bbox of first sample
    },
...
    {
        'image_id':image_id of fourth sample,
        'keypoints':[list of keypoints of fourth sample],
        'labels':[list of labels of fourth sample],
        'boxes':bbox of fourth sample
    }
]

这是我的 __getitem__ 函数:

def __getitem__(self, idx):
    annotation = self.annotations[idx]
    image_id = annotation['image_id']
    file_name = annotation['file_name']
    image_path = f"{self.images_dir}/{file_name}"
    image = Image.open(image_path).convert("RGB")
    bbox = np.array(annotation['bbox'])
    keypoints = np.array([[ann["x"],ann["y"]] for ann in annotation["keypoints"]])
    labels = np.array([kp_num[ann["name"]] for ann in annotation["keypoints"]])
    target = {
        "image_id":image_id,
        "keypoints": torch.tensor(keypoints, dtype=torch.float32),
        "labels": torch.tensor(labels, dtype=torch.int64),
        "boxes":torch.tensor(bbox, dtype=torch.int)
    }
    
    if self.transform:
        image = self.transform(image)

    return image, target

我希望它返回一个目标列表,但它却返回了一个包含列表的字典。我尝试把目标放在返回语句中,放进一个只有一个元素的列表里,但结果却是返回了一个只有一个条目的列表,里面包含了所有信息,而不是一个包含多个目标的列表。我正在使用 torch.utils.data 的 DataLoader 类。

编辑:问题解决了,我实现了一个自定义的 collate 函数,像这样:

def custom_collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch] 
    return images, targets

1 个回答

0

默认的 collate_fn 就像是把你字典里的每个键的值堆叠起来。如果假设每批数据的大小是 N,那么 target['key'] 中的每个元素会变成 [N, ...]。(这样做是为了将来更好的兼容性、内存分配和优化数据集浏览。)

如果你真的需要从字典中提取出值,可以试试下面的代码。

def my_collate_fn(batch_sample):
    image = []
    target = []
    for sample in batch_sample:
        # The `sample` is returns of your __getitem__()
        image.append(sample[0])
        target.append(sample[1])

    return (torch.stack(image).contiguous(),
            target) # The `target` is no longer a tesnor as the dict cannot be stacked itself.
# --- on your later DataLoader init --- #
loader = torch.utils.data.DataLoader(collate_fn=my_collate_fn, **kwargs)

顺便提一下,基于以上原因,我建议你还是使用默认的 collate_fn

撰写回答