如何使用Accelerate在主进程中广播张量?
我想在主进程中进行一些计算,然后把结果传给其他进程。现在我的代码大概是这样的:
from accelerate.utils import broadcast
x = None
if accelerator.is_local_main_process:
x = <do_some_computation>
x = broadcast(x) # I have even tried moving this line out of the if block
print(x.shape)
但是我遇到了以下错误:
TypeError: Unsupported types (<class 'NoneType'>) passed to `_gpu_broadcast_one` . Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` s hould be passed.
这个错误的意思是,x
还是 None
,也就是说它并没有真正被传送出去。我该怎么解决这个问题呢?
1 个回答
0
这里说的是,x
不能是 None
。它必须是一个张量(tensor),而且这个张量的形状要和其他的相同,并且要在当前进程的正确设备上。我猜这是因为 broadcast
在内部会进行一个 copy_
操作。还有,空的张量也不行。为了避免这个问题,我直接创建了一个全是零的张量。
from accelerate.utils import broadcast
x = torch.zeros(*final_shape, device=accelerator.device)
if accelerator.is_local_main_process:
x = <do_some_computation>
x = broadcast(x)
print(x.shape)