如何从头开始制作PyTorch数据加载器?

2024-06-17 19:45:02 发布

您现在位置:Python中文网/ 问答频道 /正文

是否可以从头开始重新创建PyTorch DataLoader的简单版本? 该类应该能够根据批大小返回当前批

例如,下面的代码只允许我一次返回一个示例

X = np.array([[1,2],[3,4],[5,6],[6,7]])

class DataLoader:
    def __init__(self, X, b_size):
        self.X = X
        self.b_size = b_size
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index):
        return self.X[index]

但我想要实现的是,如果我指定b_size=2,它将返回:

Iteration 0: [[1,2],[3,4]]
Iteration 1: [[5,6],[7,8]]

在Python中有可能做类似的事情吗?我不能使用DataLoader类


Tags: 代码self版本示例sizeindexlenreturn
1条回答
网友
1楼 · 发布于 2024-06-17 19:45:02
X = np.array([[1,2],[3,4],[5,6],[6,7]])

class DataLoader:
    def __init__(self, X, b_size):
        self.X = X
        self.b_size = b_size
    
    def __len__(self):
        return len(self.X)//self.b_size
    
    def __getitem__(self, index):        
        return self.X[index*self.b_size:index*self.b_size+self.b_size]

d = DataLoader(X, 2)
for i in range(len(d)):
  print (f"Iteration {i}: {d[i]}")

输出:

Iteration 0: [[1 2]
 [3 4]]
Iteration 1: [[5 6]
 [6 7]]

相关问题 更多 >