由于Pytorch Lightning为模型检查点提供了自动保存功能,因此我使用它来保存top-k最佳模型。特别是在培训师设置方面
checkpoint_callback = ModelCheckpoint(
monitor='val_acc',
dirpath='checkpoints/',
filename='{epoch:02d}-{val_acc:.2f}',
save_top_k=5,
mode='max',
)
这可以正常工作,但不会保存模型对象的某些属性。我的模型在每次训练结束时都会存储一些张量,这样
class SampleNet(pl.LightningModule):
def __init__(self):
super().__init__()
self.save_hyperparameters()
self.layer = torch.nn.Linear(100, 1)
self.loss = torch.nn.CrossEntropy()
self.some_data = None # Initialize as None
def training_step(self, batch):
x, t = batch
out = self.layer(x)
loss = self.loss(out, t)
results = {'loss': loss}
return results
def training_epoch_end(self, outputs):
self.some_data = some_tensor_object
这是一个简化的示例,但我希望上面checkpoint_callback
生成的检查点文件记住属性self.some_data
,但是当我从检查点加载模型时,它总是重置为None
。我确认它在培训期间已成功更新
我试图在init
中不将其初始化为None,但在加载模型时该属性将消失
我希望避免将属性保存为不同的pt
文件,因为它与模型配置相关,因此我需要在以后手动将该文件与相应的检查点文件进行匹配
有可能在检查点文件中包含这样的张量属性吗
似乎不可能直接提取参数,因为最有可能使用^{} 。
这种方法只提取实际视为参数的张量值。因此,在这种情况下,解决方法是将数据保存为参数(请参见docs):
只需将模型类挂钩
on_save_checkpoint()
和on_load_checkpoint()
用于所有要与默认属性一起保存的对象See module docs
相关问题 更多 >
编程相关推荐