I have a problem with the PyTorch DataLoader. Whenever I try to load the next batch via iter(), the call runs indefinitely. I also tried running it in Google Colab, where it raises a RecursionError instead. Here is the loading code:
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = PostTitleDataset(
        title=df["clean_title"].to_numpy(),
        label=df["6_way_label"].to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=0
    )
BATCH_SIZE = 16
MAX_LEN = 80
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
# This call runs indefinitely or gives me a RecursionError in Google Colab
if __name__ == '__main__':
    data = next(iter(test_data_loader))
    data.keys()
Below is the error message from Google Colab. In a Jupyter notebook it just runs indefinitely with no error message:
RecursionError                            Traceback (most recent call last)
<ipython-input> in <module>()
----> 1 data = next(iter(test_data_loader))
      2 data.keys()

... last 1 frames repeated, from the frame below ...

in __getitem__(self, item)
     24     #batch = convert_to_batch(dataframe)
     25
---> 26     title = self["clean_title"][item]
     27     label = self["6_way_label"][item]
     28

RecursionError: maximum recursion depth exceeded
Does anyone know how to resolve this so that iter() successfully returns a batch?
For completeness, here is my custom dataset class:
class PostTitleDataset(Dataset):
    def __init__(self, title, label, tokenizer, max_len):
        self.title = title,
        self.label = label,
        self.tokenizer = tokenizer,
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, item):
        #batch = convert_to_batch(dataframe)
        title = self["clean_title"][item]
        label = self["6_way_label"][item]
        # Encode text input content
        encoding = self.tokenizer(
            title,
            padding=True,
            truncation=True,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # Return first post title of batch as validation
        return {
            "clean title": title,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "Label": torch.tensor(label, dtype=torch.long)
        }