PyTorch iter() runs indefinitely or throws a RecursionError

Posted 2024-04-19 12:43:58


I'm having a problem with the PyTorch DataLoader. Whenever I try to load the next batch via the iter() function, it runs indefinitely. I have also tried running it in Google Colab, where it throws a RecursionError instead. Here is the function that creates the DataLoader:

def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = PostTitleDataset(
        title = df["clean_title"].to_numpy(),
        label = df["6_way_label"].to_numpy(),
        tokenizer = tokenizer,
        max_len = max_len
    )
    
    return DataLoader(
        ds,
        batch_size = batch_size,
        num_workers = 0
    )

BATCH_SIZE = 16
MAX_LEN = 80

test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

# This function runs indefinitely or gives me a RecursionError in Google Colab
if __name__ == '__main__':
  data = next(iter(test_data_loader))
  data.keys()

Below is the error message from Google Colab. In the Jupyter notebook it just runs indefinitely with no error message:

RecursionError                            Traceback (most recent call last)
<ipython-input> in <module>()
----> 1 data = next(iter(test_data_loader))
      2 data.keys()

5 frames
... last 1 frames repeated, from the frame below ...

<ipython-input> in __getitem__(self, item)
     24         #batch = convert_to_batch(dataframe)
     25
---> 26         title = self["clean_title"][item]
     27         label = self["6_way_label"][item]
     28

RecursionError: maximum recursion depth exceeded

Does anyone know how to solve this, and how to get the iter() function to successfully return a batch?

For completeness, you can find my custom dataset class below:

class PostTitleDataset(Dataset):
    
    def __init__(self, title, label, tokenizer, max_len):
        self.title = title,
        self.label = label,
        self. tokenizer = tokenizer,
        self.max_len = max_len
        
    def __len__(self):
        return len(self.title)
    
    def __getitem__(self, item):
        
        #batch = convert_to_batch(dataframe)
        
        title = self["clean_title"][item]
        label = self["6_way_label"][item]
        
        # Encode text input content
        encoding = self.tokenizer(
            title,
            padding=True,
            truncation=True,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        
        # Return first post title of batch as validation
        return {
            "clean title": title,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "Label": torch.tensor(label, dtype=torch.long)
        }
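For reference, the RecursionError in the traceback above appears to come from indexing `self` inside `__getitem__`: `self["clean_title"][item]` calls `__getitem__` again, which recurses until the stack limit is hit. The trailing commas in `__init__` (`self.title = title,`) also wrap each attribute in a one-element tuple. Below is a corrected sketch of the class under those assumptions; it expects a HuggingFace-style tokenizer, and the switch from `padding=True` to `padding="max_length"` is my own choice so that every sample has the same length and the default collate function can stack them into a batch:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class PostTitleDataset(Dataset):

    def __init__(self, title, label, tokenizer, max_len):
        # No trailing commas here: "self.title = title," would store
        # a one-element tuple instead of the array itself.
        self.title = title
        self.label = label
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, item):
        # Index the stored arrays, not self: "self[...]" calls
        # __getitem__ again and recurses forever.
        title = str(self.title[item])
        label = self.label[item]

        # Encode the title; fixed-length padding keeps tensor shapes
        # uniform so DataLoader's default collate can batch them.
        encoding = self.tokenizer(
            title,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return {
            "clean_title": title,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "label": torch.tensor(label, dtype=torch.long),
        }
```

With this version, `next(iter(test_data_loader))` should return a batch dict whose `input_ids` tensor has shape `(batch_size, max_len)`.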