我试图使用来自外部函数的信息来决定返回哪些数据。在这里,我添加了一个简化的代码来演示这个问题。当我使用num_workers = 0
时,我得到了所需的行为(3个时代之后的输出是18)。但是,当我增加num_workers
的值时,每个历元之后的输出是相同的。全局变量保持不变
from torch.utils.data import Dataset, DataLoader
x = 6
def getx():
global x
x+=1
print("x: ", x)
return x
class MyDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
global x
x = getx()
return x
def __len__(self):
return 3
dataset = MyDataset()
loader = DataLoader(
dataset,
num_workers=0,
shuffle=False
)
for epoch in range(4):
for idx, data in enumerate(loader):
print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
当num_workers=0
如预期的那样为18时的最终输出。但是当num_workers>0
时,x保持不变(最终输出为6)
如何使用num_workers>0
获得与num_workers=0
类似的行为(即,如何确保dataloader的__getitem__
函数更改全局变量x
的值)
其原因是python中多处理的潜在性质。设置
num_workers
意味着您的DataLoader
创建了那个数量的子进程。每个子进程实际上是一个单独的python实例,具有自己的全局状态,并且不知道其他进程中发生了什么python的多处理中的一个典型解决方案是使用
Manager
。但是,由于您的多处理是通过DataLoader提供的,因此您没有办法在这方面工作幸运的是,还可以做些别的事情
DataLoader
实际上依赖于torch.multiprocessing,这反过来允许进程之间共享张量,只要它们在共享内存中所以你能做的就是,简单地用x作为共享张量
输出:
虽然这样做有效,但并不完美。看看历元1,注意这里有2个12,而不是11和12。这意味着两个单独的进程在执行print之前执行了
x+=1
行。这是不可避免的,因为并行进程正在共享内存上工作如果您熟悉操作系统的概念,您可能能够进一步实现某种semaphore,并根据需要使用一个额外的变量来控制对x的访问——但由于这超出了问题的范围,我将不作进一步阐述
相关问题 更多 >
编程相关推荐