如何修复数据集以返回所需的输出(pytorch)

2024-03-28 20:00:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图使用来自外部函数的信息来决定返回哪些数据。在这里,我添加了一个简化的代码来演示这个问题。当我使用num_workers = 0时,我得到了所需的行为(3个时代之后的输出是18)。但是,当我增加num_workers的值时,每个历元之后的输出是相同的。全局变量保持不变

from torch.utils.data import Dataset, DataLoader

x = 6
def getx():
    global x
    x+=1
    print("x: ", x)
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        global x
        x = getx()
        return x
    
    def __len__(self):
        return 3

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False
)

for epoch in range(4):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))

num_workers=0如预期的那样为18时的最终输出。但是当num_workers>0时,x保持不变(最终输出为6)

如何使用num_workers>0获得与num_workers=0类似的行为(即,如何确保dataloader的__getitem__函数更改全局变量x的值)


Tags: 函数selfdatareturndefglobaldatasetnum
1条回答
网友
1楼 · 发布于 2024-03-28 20:00:57

其原因是python中多处理的潜在性质。设置num_workers意味着您的DataLoader创建了那个数量的子进程。每个子进程实际上是一个单独的python实例,具有自己的全局状态,并且不知道其他进程中发生了什么

python的多处理中的一个典型解决方案是使用Manager。但是,由于您的多处理是通过DataLoader提供的,因此您没有办法在这方面工作

幸运的是,还可以做些别的事情DataLoader实际上依赖于torch.multiprocessing,这反过来允许进程之间共享张量,只要它们在共享内存中

所以你能做的就是,简单地用x作为共享张量

from torch.utils.data import Dataset, DataLoader
import torch 

x = torch.tensor([6])
x.share_memory_()

def getx():
    global x
    x+=1
    print("x: ", x.item())
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        global x
        x = getx()
        return x
    
    def __len__(self):
        return 3

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=False
)

for epoch in range(4):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))

输出:

x:  7
x:  8
x:  9
Epoch 0, idx 0, val: tensor([[7]])
Epoch 0, idx 1, val: tensor([[8]])
Epoch 0, idx 2, val: tensor([[9]])
x:  10
x:  11
x:  12
Epoch 1, idx 0, val: tensor([[10]])
Epoch 1, idx 1, val: tensor([[12]])
Epoch 1, idx 2, val: tensor([[12]])
x:  13
x:  14
x:  15
Epoch 2, idx 0, val: tensor([[13]])
Epoch 2, idx 1, val: tensor([[15]])
Epoch 2, idx 2, val: tensor([[14]])
x:  16
x:  17
x:  18
Epoch 3, idx 0, val: tensor([[16]])
Epoch 3, idx 1, val: tensor([[18]])
Epoch 3, idx 2, val: tensor([[17]])

虽然这样做有效,但并不完美。看看历元1,注意这里有2个12,而不是11和12。这意味着两个单独的进程在执行print之前执行了x+=1行。这是不可避免的,因为并行进程正在共享内存上工作

如果您熟悉操作系统的概念,您可能能够进一步实现某种semaphore,并根据需要使用一个额外的变量来控制对x的访问——但由于这超出了问题的范围,我将不作进一步阐述

相关问题 更多 >