如何调试ValueError: `FlatParameter`需要统一dtype,但得到了torch.float32和torch.bfloat16?

0 投票
1 回答
49 浏览
提问于 2025-04-13 00:56

我正在尝试使用Pytorch Lightning Fabric进行分布式的FSDP训练,同时在LLAMA 2上进行Huggingface PEFT LORA的微调,但我的代码出现了错误,具体错误信息是:

`FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16
  File ".......", line 100, in <module>
    model, optimizer = fabric.setup(model, optimizer)
ValueError: `FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16

我该如何找出在pytorch fabric中哪些张量是float32类型的呢?

1 个回答

撰写回答