Unable to fix: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Posted 2024-05-14 22:42:16


I am trying to reproduce a DeblurGanV2 network and am currently working on the training step. Here is the current state of my training pipeline:

import random

import torch
from torch.autograd import Variable
from tqdm import tqdm

torch.autograd.set_detect_anomaly(mode=True)
total_generator_loss = 0
total_discriminator_loss = 0
psnr_score = 0.0
used_loss_function = 'wgan_gp_loss'
for epoch in range(n_epochs):

      #set to train mode
      generator.train(); discriminator.train()
      tqdm_bar = tqdm(train_loader, desc=f'Training Epoch {epoch} ', total=int(len(train_loader)))
      for batch_idx, imgs in enumerate(tqdm_bar):
        
        # move the images to the GPU
        blurred_images = imgs["blurred"].cuda()
        sharped_images = imgs["sharp"].cuda()
        
        # generator output
        deblurred_img = generator(blurred_images)
    
        # denormalize
        with torch.no_grad():
          denormalized_blurred = denormalize(blurred_images)
          denormalized_sharp = denormalize(sharped_images)
          denormalized_deblurred = denormalize(deblurred_img)
    
        # get D's output
        sharp_discriminator_out = discriminator(sharped_images)
        deblurred_discriminator_out = discriminator(deblurred_img)
    
        # set critic_updates
        if used_loss_function == 'wgan_gp_loss':
          critic_updates = 5
        else:
          critic_updates = 1
    
        #train discriminator
        discriminator_loss = 0
        for i in range(critic_updates):
          discriminator_optimizer.zero_grad()
          # train discriminator on real and fake
          if used_loss_function == 'wgan_gp_loss':
            gp_lambda = 10
            alpha = random.random()
            interpolates = alpha * sharped_images + (1 - alpha) * deblurred_img
            interpolates_discriminator_out = discriminator(interpolates)
            kwargs = {'gp_lambda': gp_lambda,
                       'interpolates': interpolates,
                       'interpolates_discriminator_out': interpolates_discriminator_out,
                       'sharp_discriminator_out': sharp_discriminator_out,
                       'deblurred_discriminator_out': deblurred_discriminator_out
                        }
            wgan_loss_d, gp_d = wgan_gp_loss('D', **kwargs)
            discriminator_loss_per_update = wgan_loss_d + gp_d
    
          discriminator_loss_per_update.backward(retain_graph=True)
          discriminator_optimizer.step()
          discriminator_loss += discriminator_loss_per_update.item()

But when I run this code, I get the following error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

RuntimeError                              Traceback (most recent call last)
 in ()
     62 # discriminator_loss_per_update = gan_loss_d
     63
---> 64 discriminator_loss_per_update.backward(retain_graph=True)
     65 discriminator_optimizer.step()
     66 discriminator_loss += discriminator_loss_per_update.item()

1 frames
/usr/local/lib/python3.7/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    243                 create_graph=create_graph,
    244                 inputs=inputs)
--> 245         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    246
    247     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145     Variable._execution_engine.run_backward(
    146         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 147         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    148
    149

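Unfortunately, I can't track down the in-place operation that causes this error. Does anyone have ideas or suggestions? Any input would be much appreciated :)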
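
One possible cause, judging only from the loop shown above: `sharp_discriminator_out` and `deblurred_discriminator_out` are computed once before the critic loop, while `discriminator_optimizer.step()` updates the discriminator's weights in place on every iteration. The second `backward(retain_graph=True)` would then walk a graph that saved the old weights, which typically triggers this kind of version mismatch (the reported shape `[1, 512, 4, 4]` looks like a convolution weight, which would fit). The sketch below reproduces that pattern on the CPU with a tiny stand-in network. The names `critic`, `real`, `fake` and the three helper functions are hypothetical and not part of the original pipeline, and the loss is a plain Wasserstein-style difference without the gradient penalty.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real discriminator and image batches.
critic = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(critic.parameters(), lr=0.1)
real = torch.randn(2, 4)
fake = torch.randn(2, 4, requires_grad=True)  # plays the role of the generator output


def stale_graph_update(steps=2):
    """Forward once, then backward/step several times (the pattern in the question)."""
    loss = critic(fake).mean() - critic(real).mean()
    for _ in range(steps):
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()  # modifies the critic weights in place


def fresh_graph_update(steps=2):
    """Rebuild the graph with a new forward pass after every optimizer step."""
    for _ in range(steps):
        optimizer.zero_grad()
        # detach() keeps the critic update from backpropagating into the generator
        loss = critic(fake.detach()).mean() - critic(real).mean()
        loss.backward()
        optimizer.step()


def fails_with_inplace_error(update_fn):
    """Return True if update_fn raises the in-place modification RuntimeError."""
    try:
        update_fn()
    except RuntimeError as err:
        return "inplace operation" in str(err)
    return False
```

If this is indeed what happens in the original code, moving the discriminator forward passes and the interpolation inside the `for i in range(critic_updates)` loop, and feeding `deblurred_img.detach()` to the discriminator during its update, should make `retain_graph=True` unnecessary there.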

