PyTorch torchvision数据集下载速度非常慢

2024-06-16 09:23:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我在从torchvision下载EMNIST数据集的colab笔记本中有以下代码块。有时我会随机地得到一个错误,说

connectionError: HTTPConnectionPool(host='www.itl.nist.gov', port=80): Max retries exceeded with url: /iaui/vip/cs_links/EMNIST/gzip.zip (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f701d0936d0>: Failed to establish a new connection: [Errno 110] Connection timed out'))

因此,我制作了一个代码块,其中包含一个函数,该函数在尝试下载失败时会自动调用(请参阅本文底部)。有时下载确实开始工作,但速度非常慢。请参见下面的进度条截图

enter image description here 下载时间超过两小时<;屏幕截图中URL的1GB数据。将数据集直接下载到我的机器大约需要60秒,因此提供数据的服务器没有问题。这似乎是colab与服务器的互联网连接或PyTorch处理数据下载的方式的问题。我真的不知道该怎么解决这个问题。我多次尝试重新连接运行时,但同样的问题发生了

数据下载代码:

from torchvision import transforms, datasets

train_data = None
test_data = None
def load_data():
  global train_data, test_data
  try:
    train_data = datasets.EMNIST("./data", split="balanced", train=True, download=True,
                                transform=transforms.Compose([
                                                              transforms.ToTensor()
                                ]))

    test_data = datasets.EMNIST("./data", split="balanced", train=False, download=True,
                                transform=transforms.Compose([
                                                              transforms.ToTensor()
                                ]))
  except:
    load_data()

load_data()
print(train_data)
print(test_data)

Tags: 数据函数代码test服务器truedataload