我正在编写一个定制的pytorch数据集。在__init__
中,dataset对象加载包含特定数据的文件。但在我的程序中,我只希望访问部分数据(如果有帮助的话,实现训练/有效切割)。起初我认为这种行为是通过重写__len__
来控制的,但结果证明修改__len__
没有帮助。一个简单的例子如下:
from torch.utils.data import Dataset, DataLoader
import torch
class NewDS(Dataset):
def __init__(self):
self.data = torch.randn(10,2) # suppose there are 10 items in the data file
def __len__(self):
return len(self.data)-5 # But I only want to access the first 5 items
def __getitem__(self, index):
return self.data[index]
ds = NewDS()
for i, x in enumerate(ds):
print(i)
输出为0到9,而期望的行为为0到4
在这样的for循环中使用此dataset对象时,该对象如何知道枚举已结束?任何其他达到类似效果的方法都是受欢迎的
您正在使用
Dataset
类创建自定义数据加载器,同时使用for循环枚举它。这不是它的工作原理。对于枚举,必须将Dataset
传递给DataLoader
类。你的代码会像这样工作的很好更多详情可在本官方网站pytorch tutorial上阅读
您可以使用^{} 获取数据的子集
循环中的^{} 将返回项,直到它获得^{} 异常
输出:
相关问题 更多 >
编程相关推荐