PyTorch Lightning:在检查点文件中包含一些张量对象

2024-04-20 05:29:22 发布

您现在位置:Python中文网/ 问答频道 /正文

由于Pytorch Lightning为模型检查点提供了自动保存功能,因此我使用它来保存top-k最佳模型。特别是在培训师设置方面

    checkpoint_callback = ModelCheckpoint(
                            monitor='val_acc',
                            dirpath='checkpoints/',
                            filename='{epoch:02d}-{val_acc:.2f}',
                            save_top_k=5,
                            mode='max',
                            )

这可以正常工作,但不会保存模型对象的某些属性。我的模型在每次训练结束时都会存储一些张量,这样

class SampleNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(100, 1)
        self.loss = torch.nn.CrossEntropy()

        self.some_data = None # Initialize as None

    def training_step(self, batch):
        x, t = batch
        out = self.layer(x)
        loss = self.loss(out, t)
        results = {'loss': loss}

        return results

    def training_epoch_end(self, outputs):
       
        self.some_data = some_tensor_object

这是一个简化的示例,但我希望上面checkpoint_callback生成的检查点文件记住属性self.some_data,但是当我从检查点加载模型时,它总是重置为None。我确认它在培训期间已成功更新

我试图在init中不将其初始化为None,但在加载模型时该属性将消失

我希望避免将属性保存为不同的pt文件,因为它与模型配置相关,因此我需要在以后手动将该文件与相应的检查点文件进行匹配

有可能在检查点文件中包含这样的张量属性吗


Tags: 文件模型selfnonedata属性inittop
2条回答

似乎不可能直接提取参数,因为最有可能使用^{}。 这种方法只提取实际视为参数的张量值。因此,在这种情况下,解决方法是将数据保存为参数(请参见docs):

self.some_data = torch.nn.parameter.Parameter(your_data)

只需将模型类挂钩on_save_checkpoint()on_load_checkpoint()用于所有要与默认属性一起保存的对象

def on_save_checkpoint(self, checkpoint) -> None:
    "Objects to include in checkpoint file"
    checkpoint["some_data"] = self.some_data

def on_load_checkpoint(self, checkpoint) -> None:
    "Objects to retrieve from checkpoint file"
    self.some_data= checkpoint["some_data"]

See module docs

相关问题 更多 >