PyTorch torch.no_grad()与requires_grad=False

2024-04-20 06:04:37 发布

您现在位置:Python中文网/ 问答频道 /正文

我下面介绍一个PyTorch tutorial,它使用Huggingface Transformers库中的BERT NLP模型(特征提取器)。梯度更新有两段相互关联的代码,我不理解

(一)

本教程有一个类,其中forward()函数围绕对BERT特征提取器的调用创建一个torch.no_grad()块,如下所示:

bert = BertModel.from_pretrained('bert-base-uncased')

class BERTGRUSentiment(nn.Module):
    
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        
    def forward(self, text):
        with torch.no_grad():
            embedded = self.bert(text)[0]

(2)param.requires_grad = False

同一教程中还有另一部分冻结了BERT参数

for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False

我什么时候需要(1)和/或(2)?

  • 如果我想训练一个冰冻的伯特,我需要同时启用这两个吗
  • 如果我想训练让BERT更新,我是否需要禁用这两个功能

此外,我运行了所有四种组合并发现:

   with torch.no_grad   requires_grad = False  Parameters  Ran
   ------------------   ---------------------  ----------  ---
a. Yes                  Yes                      3M        Successfully
b. Yes                  No                     112M        Successfully
c. No                   Yes                      3M        Successfully
d. No                   No                     112M        CUDA out of memory

有人能解释一下发生了什么事吗?为什么我得到的是(d)而不是(b)?两者都有112M可学习的参数


Tags: noselffalseparamdef教程torchyes
1条回答
网友
1楼 · 发布于 2024-04-20 06:04:37

这是一个较老的讨论,多年来略有变化(主要是因为with torch.no_grad()作为一种模式的目的。在on Stackoverflow already中可以找到一个很好的答案来回答您的问题。
然而,由于原来的问题有很大的不同,我将避免标记为重复,特别是由于关于记忆的第二部分

no_grad的初步解释如下here

with torch.no_grad() is a context manager and is used to prevent calculating gradients [...].

另一方面,使用requires_grad

to freeze part of your model and train the rest [...].

来源再次the SO post

本质上,使用requires_grad只会禁用网络的一部分,而no_grad根本不会存储任何梯度,因为您可能将其用于推理而不是训练。
要分析参数组合的行为,让我们调查发生了什么:

  • a)b)根本不存储任何渐变,这意味着无论参数有多少,您都有更多的可用内存,因为您没有为潜在的向后传递保留它们
  • c)必须为以后的反向传播存储前向传递,但是,只存储有限数量的参数(300万),这使得这仍然是可管理的
  • d)但是,需要为所有1.12亿参数存储前向传递,这会导致内存不足

相关问题 更多 >