将Pytorch bfloat16张量转换为numpy时抛出TypeError

0 投票
1 回答
303 浏览
提问于 2025-04-14 17:25

当你尝试把一个Torch的bfloat16张量转换成numpy数组时,会出现一个TypeError错误:

import torch

x = torch.Tensor([0]).to(torch.bfloat16)
x.numpy()  # TypeError: Got unsupported ScalarType BFloat16

import numpy as np
np.array(x)  # same error

有没有什么办法可以解决这个转换问题呢?

1 个回答

0

目前,numpy 不支持 bfloat16格式。一个解决办法是先把张量从半精度转换为单精度,然后再进行转换:

x.float().numpy()

Pytorch 的维护者们 也在考虑Tensor.numpy 方法中添加一个 force=True 选项,这样就可以自动处理这个问题。

** 不过,这个情况可能会改变,因为有 @jakevdp 的 相关工作

撰写回答