意外的打印信息干扰PyTorch训练中的tqdm进度条

1 投票
1 回答
65 浏览
提问于 2025-04-13 16:02

我正在尝试理解如何使用 tqdm 来实现进度条。我的代码大致如下:

import torch
import torchvision
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

load_data()
manual_transforms = transforms.Compose([])
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()

# them within the main function I have placed the train function that exists in the `engine.py` file
def main():

      results = engine.train(model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        epochs=5,
        device=device)

engine.train() 函数中,有一段代码是 for epoch in tqdm(range(epochs)):,这样可以在训练每个批次时显示训练的进度。每次 tqdm 执行时,它还会打印出以下内容:

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

所以,最后我想问的是,为什么会这样?主函数是如何访问这些全局语句的?我该如何避免在每次循环中打印所有内容呢?

1 个回答

1

你所注意到的情况其实和 tqdm 没有关系,而是和 PyTorch 的内部工作原理有关,特别是 DataLoadernum_workers 属性,以及 Python 的多进程处理框架。下面是一个简单的示例,可以重现你遇到的问题:

from contextlib import suppress
from multiprocessing import set_start_method
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
print("torch version:", torch.__version__)

class DummyData(Dataset):
    def __len__(self): return 256
    def __getitem__(self, i): return i

def main():
    for batch in tqdm(DataLoader(DummyData(), batch_size=16, num_workers=4)):
        pass  # Do something
    
if __name__ == "__main__":
    # Enforce "spawn" method (e.g. on Linux) for subprocess creation to
    # reproduce problem (suppress error for reruns in same interpreter)
    with suppress(RuntimeError): set_start_method("spawn")
    main()

如果你运行这段代码,你会发现你的 PyTorch 版本号会被打印四次,这样就会搞乱你的 tqdm 进度条。这个数字和 num_workers 的值是一样的,这并不是巧合(你可以通过改变这个数字来轻松验证)。

发生的事情是这样的:

  • 如果 num_workers 大于 0,就会为工作进程启动子进程。
  • 在 Windows 和 macOS 上,这些子进程默认使用“spawn”方法启动(在 Linux 上,可以强制使用这种方法来重现你的观察,我用 set_start_method() 做了这个)。
  • “spawn”方法会为每个子进程启动你的主脚本,执行所有没有被 if __name__ == "__main__": 保护的代码。这包括你在脚本顶部的 print() 调用。

这种行为在 这里 有详细说明,还有一些可能的解决办法。我想对你有效的办法是:

把大部分主脚本的代码放在 if __name__ == '__main__': 这个块里,以确保它不会再次运行。

所以,你可以选择:

  1. print() 调用移动到 if __name__ == '__main__': 块的开头,
  2. print() 调用移动到 main() 函数的开头,或者
  3. 直接去掉 print() 调用。

另外,你也可以选择设置 num_workers=0,这样就完全禁用了多进程的使用(但这样你也会失去并行处理的好处)。请注意,你可能还需要把其他函数调用(比如 load_data())放到 if __name__ == '__main__': 块里或者 main() 函数里,以避免多次意外执行。

撰写回答