How to debug ValueError: `FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16?
I'm trying to run distributed FSDP training with PyTorch Lightning Fabric while fine-tuning LLaMA 2 with Hugging Face PEFT LoRA, but my code fails with the following error:
`FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16
File ".......", line 100, in <module>
model, optimizer = fabric.setup(model, optimizer)
ValueError: `FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16
How can I find out which tensors are float32 in PyTorch Fabric?
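One way to locate the offending tensors is to iterate over the model's parameters and group them by dtype before calling `fabric.setup(model, optimizer)`. The sketch below is a generic PyTorch diagnostic, not a Fabric-specific API; the toy `nn.Sequential` model and the `report_param_dtypes` helper are illustrative stand-ins (in your case, a PEFT-wrapped LLaMA 2 whose base weights are bfloat16 while the LoRA adapter weights were left in float32 would show up the same way):

```python
from collections import defaultdict

import torch
from torch import nn

def report_param_dtypes(model: nn.Module) -> dict:
    """Group parameter names by dtype so mixed-precision offenders stand out."""
    by_dtype = defaultdict(list)
    for name, param in model.named_parameters():
        by_dtype[param.dtype].append(name)
    for dtype, names in by_dtype.items():
        print(f"{dtype}: {len(names)} params, e.g. {names[:3]}")
    return dict(by_dtype)

# Demo: a toy model with deliberately mixed dtypes (a stand-in for
# bf16 base weights plus LoRA adapters left in float32).
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[0].to(torch.bfloat16)  # simulate bf16 base weights

dtypes = report_param_dtypes(model)
```

Running this right before `fabric.setup(...)` should list every parameter name under each dtype, so you can see exactly which tensors are still float32. FSDP flattens parameters into a single `FlatParameter`, which is why it demands one uniform dtype across everything it wraps.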