将Pytorch bfloat16张量转换为numpy时抛出TypeError
当你尝试把一个Torch的bfloat16张量转换成numpy数组时,会出现一个TypeError
错误:
import torch
x = torch.Tensor([0]).to(torch.bfloat16)
x.numpy() # TypeError: Got unsupported ScalarType BFloat16
import numpy as np
np.array(x) # same error
有没有什么办法可以解决这个转换问题呢?
1 个回答
0
目前,numpy
不支持 bfloat16格式。一个解决办法是先把张量从半精度转换为单精度,然后再进行转换:
x.float().numpy()
Pytorch 的维护者们 也在考虑在 Tensor.numpy
方法中添加一个 force=True
选项,这样就可以自动处理这个问题。