迭代torch数据集未加载多个目标
我正在尝试从文件中加载一个数据集,并在其上训练一个人工智能模型。出于某种原因,当我在主程序中使用 for iamges, targets in dataloader
时,它加载的目标数据是这样的:
{
'image_id':[all image ids],
'keypoints':[all keypoint lists],
'labels':[all label lists],
'boxes':[all bboxes]
}
而不是我期望的这样:
[
{
'image_id':image_id of first sample,
'keypoints':[list of keypoints of first sample],
'labels':[list of labels of first sample],
'boxes':bbox of first sample
},
...
{
'image_id':image_id of fourth sample,
'keypoints':[list of keypoints of fourth sample],
'labels':[list of labels of fourth sample],
'boxes':bbox of fourth sample
}
]
这是我的 __getitem__
函数:
def __getitem__(self, idx):
annotation = self.annotations[idx]
image_id = annotation['image_id']
file_name = annotation['file_name']
image_path = f"{self.images_dir}/{file_name}"
image = Image.open(image_path).convert("RGB")
bbox = np.array(annotation['bbox'])
keypoints = np.array([[ann["x"],ann["y"]] for ann in annotation["keypoints"]])
labels = np.array([kp_num[ann["name"]] for ann in annotation["keypoints"]])
target = {
"image_id":image_id,
"keypoints": torch.tensor(keypoints, dtype=torch.float32),
"labels": torch.tensor(labels, dtype=torch.int64),
"boxes":torch.tensor(bbox, dtype=torch.int)
}
if self.transform:
image = self.transform(image)
return image, target
我希望它返回一个目标列表,但它却返回了一个包含列表的字典。我尝试把目标放在返回语句中,放进一个只有一个元素的列表里,但结果却是返回了一个只有一个条目的列表,里面包含了所有信息,而不是一个包含多个目标的列表。我正在使用 torch.utils.data 的 DataLoader 类。
编辑:问题解决了,我实现了一个自定义的 collate 函数,像这样:
def custom_collate_fn(batch):
images = [item[0] for item in batch]
targets = [item[1] for item in batch]
return images, targets
1 个回答
0
默认的 collate_fn
就像是把你字典里的每个键的值堆叠起来。如果假设每批数据的大小是 N,那么 target['key']
中的每个元素会变成 [N, ...]。(这样做是为了将来更好的兼容性、内存分配和优化数据集浏览。)
如果你真的需要从字典中提取出值,可以试试下面的代码。
def my_collate_fn(batch_sample):
image = []
target = []
for sample in batch_sample:
# The `sample` is returns of your __getitem__()
image.append(sample[0])
target.append(sample[1])
return (torch.stack(image).contiguous(),
target) # The `target` is no longer a tesnor as the dict cannot be stacked itself.
# --- on your later DataLoader init --- #
loader = torch.utils.data.DataLoader(collate_fn=my_collate_fn, **kwargs)
顺便提一下,基于以上原因,我建议你还是使用默认的 collate_fn
。