Cannot access all the data when enumerating over a DataLoader


I defined a custom Dataset and a custom DataLoader, and I want to access all of the batches with for i, batch in enumerate(loader). But this for loop gives me a different number of batches in every epoch, and all of them are far smaller than the actual number of batches (which equals number_of_samples / batch_size).

Here is how I define the dataset and the dataloader:



import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class UsptoDataset(Dataset):
    def __init__(self, csv_file):
        df = pd.read_csv(csv_file)
        self.rea_trees = df['reactants_trees'].to_numpy()
        self.syn_trees = df['synthons_trees'].to_numpy()
        self.syn_smiles = df['synthons'].to_numpy()
        self.product_smiles = df['product'].to_numpy()

    def __len__(self):
        return len(self.rea_trees)

    def __getitem__(self, item):
        rea_tree = self.rea_trees[item]
        syn_tree = self.syn_trees[item]
        syn_smile = self.syn_smiles[item]
        pro_smile = self.product_smiles[item]
        # omit the snippet used to process the data here, which gives us the variables used in the return statement.
        return {'input_words': input_words,
                'input_chars': input_chars,
                'syn_tree_indices': syn_tree_indices,
                'syn_rule_nl_left': syn_rule_nl_left,
                'syn_rule_nl_right': syn_rule_nl_right,
                'rea_tree_indices': rea_tree_indices,
                'rea_rule_nl_left': rea_rule_nl_left,
                'rea_rule_nl_right': rea_rule_nl_right,
                'class_mask': class_mask,
                'query_paths': query_paths,
                'labels': labels,
                'parent_matrix': parent_matrix,
                'syn_parent_matrix': syn_parent_matrix,
                'path_lens': path_lens,
                'syn_path_lens': syn_path_lens}

    @staticmethod
    def collate_fn(batch):
        input_words = torch.tensor(np.stack([_['input_words'] for _ in batch], axis=0), dtype=torch.long)
        input_chars = torch.tensor(np.stack([_['input_chars'] for _ in batch], axis=0), dtype=torch.long)
        syn_tree_indices = torch.tensor(np.stack([_['syn_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_left = torch.tensor(np.stack([_['syn_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_right = torch.tensor(np.stack([_['syn_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        rea_tree_indices = torch.tensor(np.stack([_['rea_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_left = torch.tensor(np.stack([_['rea_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_right = torch.tensor(np.stack([_['rea_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        class_mask = torch.tensor(np.stack([_['class_mask'] for _ in batch], axis=0), dtype=torch.float32)
        query_paths = torch.tensor(np.stack([_['query_paths'] for _ in batch], axis=0), dtype=torch.long)
        labels = torch.tensor(np.stack([_['labels'] for _ in batch], axis=0), dtype=torch.long)
        parent_matrix = torch.tensor(np.stack([_['parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        syn_parent_matrix = torch.tensor(np.stack([_['syn_parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        path_lens = torch.tensor(np.stack([_['path_lens'] for _ in batch], axis=0), dtype=torch.long)
        syn_path_lens = torch.tensor(np.stack([_['syn_path_lens'] for _ in batch], axis=0), dtype=torch.long)

        return_dict = {'input_words': input_words,
                       'input_chars': input_chars,
                       'syn_tree_indices': syn_tree_indices,
                       'syn_rule_nl_left': syn_rule_nl_left,
                       'syn_rule_nl_right': syn_rule_nl_right,
                       'rea_tree_indices': rea_tree_indices,
                       'rea_rule_nl_left': rea_rule_nl_left,
                       'rea_rule_nl_right': rea_rule_nl_right,
                       'class_mask': class_mask,
                       'query_paths': query_paths,
                       'labels': labels,
                       'parent_matrix': parent_matrix,
                       'syn_parent_matrix': syn_parent_matrix,
                       'path_lens': path_lens,
                       'syn_path_lens': syn_path_lens}
        return return_dict


train_dataset = UsptoDataset("train_trees.csv")

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=UsptoDataset.collate_fn)         

When I use the dataloader as follows, it gives me a different number of batches in each epoch:

epoch_steps = len(train_loader)
for e in range(epochs):
    for j, batch_data in enumerate(train_loader):
        step = e * epoch_steps + j

The logs show that the first epoch has only 5 batches, the second epoch has 3, the third has 5, and so on:

Config:
Namespace(batch_size_per_gpu=4, epochs=400, eval_every_epoch=1, hidden_size=128, keep=10, log_every_step=1, lr=0.001, new_model=False, save_dir='saved_model/', workers=1)
2021-01-06 15:33:17,909 - __main__ - WARNING - Checkpoints not found in dir saved_model/, creating a new model.
2021-01-06 15:33:18,340 - __main__ - INFO - Step: 0, Loss: 5.4213, Rule acc: 0.1388
2021-01-06 15:33:18,686 - __main__ - INFO - Step: 1, Loss: 4.884, Rule acc: 0.542
2021-01-06 15:33:18,941 - __main__ - INFO - Step: 2, Loss: 4.6205, Rule acc: 0.6122
2021-01-06 15:33:19,174 - __main__ - INFO - Step: 3, Loss: 4.4442, Rule acc: 0.61
2021-01-06 15:33:19,424 - __main__ - INFO - Step: 4, Loss: 4.3033, Rule acc: 0.6211
2021-01-06 15:33:20,684 - __main__ - INFO - Dev Loss: 3.5034, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5970844200679234, in epoch 0
2021-01-06 15:33:22,203 - __main__ - INFO - Test Loss: 3.4878, Test Sample Acc: 0.0, Test Rule Acc: 0.6470248053471247
2021-01-06 15:33:22,394 - __main__ - INFO - Found better dev sample accuracy 0.0 in epoch 0
2021-01-06 15:33:22,803 - __main__ - INFO - Step: 10002, Loss: 3.6232, Rule acc: 0.6555
2021-01-06 15:33:23,046 - __main__ - INFO - Step: 10003, Loss: 3.53, Rule acc: 0.6442
2021-01-06 15:33:23,286 - __main__ - INFO - Step: 10004, Loss: 3.4907, Rule acc: 0.6498
2021-01-06 15:33:24,617 - __main__ - INFO - Dev Loss: 3.3081, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5980878387178693, in epoch 1
2021-01-06 15:33:26,215 - __main__ - INFO - Test Loss: 3.2859, Test Sample Acc: 0.0, Test Rule Acc: 0.6466992994149526
2021-01-06 15:33:26,857 - __main__ - INFO - Step: 20004, Loss: 3.3965, Rule acc: 0.6493
2021-01-06 15:33:27,093 - __main__ - INFO - Step: 20005, Loss: 3.3797, Rule acc: 0.6314
2021-01-06 15:33:27,353 - __main__ - INFO - Step: 20006, Loss: 3.3959, Rule acc: 0.5727
2021-01-06 15:33:27,609 - __main__ - INFO - Step: 20007, Loss: 3.3632, Rule acc: 0.6279
2021-01-06 15:33:27,837 - __main__ - INFO - Step: 20008, Loss: 3.3331, Rule acc: 0.6158
2021-01-06 15:33:29,122 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 2
2021-01-06 15:33:30,689 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:32,143 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 3
2021-01-06 15:33:33,765 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:34,359 - __main__ - INFO - Step: 40008, Loss: 3.108, Rule acc: 0.6816
2021-01-06 15:33:34,604 - __main__ - INFO - Step: 40009, Loss: 3.0756, Rule acc: 0.6732
2021-01-06 15:33:35,823 - __main__ - INFO - Dev Loss: 3.0419, Dev Sample Acc: 0.0, Dev Rule Acc: 0.613776079245976, in epoch 4

For reference, the values of len(train_loader.dataset), batch_size, and len(train_loader) are 40008, 4, and 10002 respectively, which is exactly what I expect. So the fact that enumerate only gives me a few batches, such as 3 or 5 (where 10002 is expected), is very confusing.
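One way to see the discrepancy directly, with the training logic stripped out, is to count how many batches each pass over the loader actually yields (a minimal sanity check, assuming the train_loader defined above):

for e in range(3):
    n_batches = sum(1 for _ in train_loader)  # exhaust the loader once
    print(f"pass {e}: {n_batches} batches, expected {len(train_loader)}")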


1 Answer

I am not sure what is wrong with your code. From what I can tell, what you are trying to do in collate_fn is to gather and stack data of the same feature type across the batch. Something like this:

You are working with the keys input_words, input_chars, syn_tree_indices, syn_rule_nl_left, syn_rule_nl_right, rea_tree_indices, rea_rule_nl_left, rea_rule_nl_right, class_mask, query_paths, labels, parent_matrix, syn_parent_matrix, path_lens, and syn_path_lens. In my example we will stick to just a, b, c, and d to keep things simple.

  • __getitem__ returns a single data point from the dataset. In your case, and in ours, this will be a dictionary: {'a': ..., 'b': ..., 'c': ..., 'd': ...}

  • collate_fn sits as an intermediate layer between the dataset and the dataloader when the data is returned. It receives the list of batch elements (collected one by one via __getitem__), and what you return from it is the collated batch: something that turns [{'a': ..., 'b': ..., 'c': ..., 'd': ...}, ...] into {'a': [...], 'b': [...], 'c': [...], 'd': [...]}, where the 'a' key contains all the data for feature a (see the sketch just below).
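For our simplified keys, such a collate_fn could be written generically as follows (a minimal sketch of the list-of-dicts to dict-of-tensors step only, leaving out your actual preprocessing):

import torch

def collate(batch):
    # batch is the list of per-sample dicts returned by __getitem__;
    # regroup it into a single dict with one stacked tensor per key.
    return {key: torch.tensor([sample[key] for sample in batch])
            for key in batch[0]}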

Now, what you may not know is that for this simple kind of collation you don't actually need a collate_fn at all. Tuples and dictionaries are handled automatically by the PyTorch dataloader, which means that if you return a dictionary from __getitem__, the dataloader will automatically collate it by key.

Here it is again with our minimal example:

class D(Dataset):
    def __init__(self):
        super(D, self).__init__()
        self.a = [1,11,111,1111,11111]
        self.b = [2,22,222,2222,22222]
        self.c = [3,33,333,3333,33333]
        self.d = [4,44,444,4444,44444]

    def __getitem__(self, i):
        return {
            'a': self.a[i],
            'b': self.b[i],
            'c': self.c[i],
            'd': self.d[i]
        }

    def __len__(self):
        return len(self.a)

As you can see in the printout below, the data is collected by key:

>>> ds = D()
>>> dl = DataLoader(ds, batch_size=2, shuffle=True)

>>> for i, x in enumerate(dl):
>>>    print(i, x)
0 {'a': tensor([11, 1111]), 'b': tensor([22, 2222]), 'c': tensor([33, 3333]), 'd': tensor([44, 4444])}
1 {'a': tensor([1, 11111]), 'b': tensor([2, 22222]), 'c': tensor([3, 33333]), 'd': tensor([4, 44444])}
2 {'a': tensor([111]), 'b': tensor([222]), 'c': tensor([333]), 'd': tensor([444])}

Providing the collate_fn argument replaces this automatic collation.
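If you do need a custom collate_fn for extra batch-level processing, one option is to delegate the standard per-key stacking to PyTorch's built-in collation (a sketch, assuming PyTorch >= 1.11, where default_collate is publicly exported):

from torch.utils.data import default_collate

def collate_fn(batch):
    collated = default_collate(batch)  # same per-key stacking as the automatic behavior
    # ...any extra batch-level processing goes here...
    return collated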
