意外的打印信息干扰PyTorch训练中的tqdm进度条
我正在尝试理解如何使用 tqdm
来实现进度条。我的代码大致如下:
import torch
import torchvision
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
load_data()
manual_transforms = transforms.Compose([])
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()
# them within the main function I have placed the train function that exists in the `engine.py` file
def main():
results = engine.train(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
optimizer=optimizer,
loss_fn=loss_fn,
epochs=5,
device=device)
在 engine.train()
函数中,有一段代码是 for epoch in tqdm(range(epochs)):
,这样可以在训练每个批次时显示训练的进度。每次 tqdm
执行时,它还会打印出以下内容:
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
所以,最后我想问的是,为什么会这样?主函数是如何访问这些全局语句的?我该如何避免在每次循环中打印所有内容呢?
1 个回答
1
你所注意到的情况其实和 tqdm
没有关系,而是和 PyTorch 的内部工作原理有关,特别是 DataLoader
的 num_workers
属性,以及 Python 的多进程处理框架。下面是一个简单的示例,可以重现你遇到的问题:
from contextlib import suppress
from multiprocessing import set_start_method
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
print("torch version:", torch.__version__)
class DummyData(Dataset):
def __len__(self): return 256
def __getitem__(self, i): return i
def main():
for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=4)):
pass # Do something
if __name__ == "__main__":
# Enforce "spawn" method (e.g. on Linux) for subprocess creation to
# reproduce problem (suppress error for reruns in same interpreter)
with suppress(RuntimeError): set_start_method("spawn")
main()
如果你运行这段代码,你会发现你的 PyTorch 版本号会被打印四次,这样就会搞乱你的 tqdm
进度条。这个数字和 num_workers
的值是一样的,这并不是巧合(你可以通过改变这个数字来轻松验证)。
发生的事情是这样的:
- 如果
num_workers
大于 0,就会为工作进程启动子进程。 - 在 Windows 和 macOS 上,这些子进程默认使用“spawn”方法启动(在 Linux 上,可以强制使用这种方法来重现你的观察,我用
set_start_method()
做了这个)。 - “spawn”方法会为每个子进程启动你的主脚本,执行所有没有被
if __name__ == "__main__":
保护的代码。这包括你在脚本顶部的print()
调用。
这种行为在 这里 有详细说明,还有一些可能的解决办法。我想对你有效的办法是:
把大部分主脚本的代码放在
if __name__ == '__main__':
这个块里,以确保它不会再次运行。
所以,你可以选择:
- 把
print()
调用移动到if __name__ == '__main__':
块的开头, - 把
print()
调用移动到main()
函数的开头,或者 - 直接去掉
print()
调用。
另外,你也可以选择设置 num_workers=0
,这样就完全禁用了多进程的使用(但这样你也会失去并行处理的好处)。请注意,你可能还需要把其他函数调用(比如 load_data()
)放到 if __name__ == '__main__':
块里或者 main()
函数里,以避免多次意外执行。