I am currently trying to reproduce a DeblurGANv2 network and am working on the training loop. Here is the current state of my training pipeline:
import random

import torch
from torch.autograd import Variable
from tqdm import tqdm

torch.autograd.set_detect_anomaly(mode=True)
total_generator_loss = 0
total_discriminator_loss = 0
psnr_score = 0.0
used_loss_function = 'wgan_gp_loss'
for epoch in range(n_epochs):
    # set to train mode
    generator.train(); discriminator.train()
    tqdm_bar = tqdm(train_loader, desc=f'Training Epoch {epoch} ', total=int(len(train_loader)))
    for batch_idx, imgs in enumerate(tqdm_bar):
        # move imgs to the GPU
        blurred_images = imgs["blurred"].cuda()
        sharped_images = imgs["sharp"].cuda()
        # generator output
        deblurred_img = generator(blurred_images)
        # denormalize
        with torch.no_grad():
            denormalized_blurred = denormalize(blurred_images)
            denormalized_sharp = denormalize(sharped_images)
            denormalized_deblurred = denormalize(deblurred_img)
        # get D's output
        sharp_discriminator_out = discriminator(sharped_images)
        deblurred_discriminator_out = discriminator(deblurred_img)
        # set critic_updates
        if used_loss_function == 'wgan_gp_loss':
            critic_updates = 5
        else:
            critic_updates = 1
        # train discriminator
        discriminator_loss = 0
        for i in range(critic_updates):
            discriminator_optimizer.zero_grad()
            # train discriminator on real and fake
            if used_loss_function == 'wgan_gp_loss':
                gp_lambda = 10
                alpha = random.random()
                interpolates = alpha * sharped_images + (1 - alpha) * deblurred_img
                interpolates_discriminator_out = discriminator(interpolates)
                kwargs = {'gp_lambda': gp_lambda,
                          'interpolates': interpolates,
                          'interpolates_discriminator_out': interpolates_discriminator_out,
                          'sharp_discriminator_out': sharp_discriminator_out,
                          'deblurred_discriminator_out': deblurred_discriminator_out
                          }
                wgan_loss_d, gp_d = wgan_gp_loss('D', **kwargs)
                discriminator_loss_per_update = wgan_loss_d + gp_d
            discriminator_loss_per_update.backward(retain_graph=True)
            discriminator_optimizer.step()
            discriminator_loss += discriminator_loss_per_update.item()
But when I run this code, I get the following error message:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
RuntimeError Traceback (most recent call last) in ()
     62 # discriminator_loss_per_update = gan_loss_d
     63
---> 64 discriminator_loss_per_update.backward(retain_graph=True)
     65 discriminator_optimizer.step()
     66 discriminator_loss += discriminator_loss_per_update.item()

1 frames
/usr/local/lib/python3.7/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    243 create_graph=create_graph,
    244 inputs=inputs)
--> 245 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    246
    247 def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    145 Variable.execution_engine.run_backward(
    146 tensors, grad_tensors, retain_graph, create_graph, inputs,
--> 147 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
    148
    149
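Unfortunately, I have not been able to track down the in-place operation that causes this error. Does anyone have an idea or a suggestion? Any input would be much appreciated.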
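For what it's worth, the shape in the error message, [1, 512, 4, 4], looks more like a convolution weight than an activation. That would point at `discriminator_optimizer.step()` as the in-place operation: it updates the discriminator weights in place, while the next critic iteration calls `backward()` again on outputs that were computed once, before the loop, with the old weights. The sketch below reproduces that pattern in isolation. The two-layer `critic`, its sizes, and the loss are made up for illustration and are not the DeblurGANv2 discriminator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the discriminator. Two layers are needed so that
# backpropagating through the second layer requires its (saved) weight.
critic = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
optimizer = torch.optim.SGD(critic.parameters(), lr=0.1)

x = torch.randn(2, 4)
out = critic(x)  # forward pass done once, outside the update loop

error_message = None
for step in range(2):
    optimizer.zero_grad()
    loss = out.mean()
    try:
        # retain_graph=True keeps the old graph alive, but the graph still
        # refers to the weights as they were during the forward pass.
        loss.backward(retain_graph=True)
    except RuntimeError as exc:
        error_message = str(exc)
        break
    optimizer.step()  # modifies the weights in place

print(error_message)
```

The first iteration succeeds. The second fails because the weights saved for the backward pass were changed by the optimizer step in between.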
Try changing the last line to:
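Beyond a one-line change, one way to restructure the critic loop so the stale-graph problem cannot occur is to run the discriminator inside the loop on every update and to detach the generator output for the discriminator step. The following is only a sketch under assumptions: the toy `generator` and `discriminator`, the tensor sizes, and the inlined loss and gradient penalty are placeholders, not the `wgan_gp_loss` function from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models standing in for the real generator/discriminator.
generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
discriminator = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.LeakyReLU(0.2),
    nn.Flatten(),
    nn.Linear(8 * 8 * 8, 1),
)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

blurred_images = torch.randn(2, 3, 8, 8)
sharped_images = torch.randn(2, 3, 8, 8)
gp_lambda = 10
critic_updates = 5

deblurred_img = generator(blurred_images)
fake = deblurred_img.detach()  # the critic update should not reach the generator

discriminator_loss = 0.0
for _ in range(critic_updates):
    discriminator_optimizer.zero_grad()

    # Fresh forward passes, so the graph matches the current weights.
    sharp_out = discriminator(sharped_images)
    fake_out = discriminator(fake)

    # Gradient penalty on random interpolates (one alpha per sample).
    alpha = torch.rand(sharped_images.size(0), 1, 1, 1)
    interpolates = (alpha * sharped_images + (1 - alpha) * fake).requires_grad_(True)
    interpolates_out = discriminator(interpolates)
    (grads,) = torch.autograd.grad(
        outputs=interpolates_out.sum(),
        inputs=interpolates,
        create_graph=True,  # the penalty itself must be differentiable
    )
    gp = gp_lambda * ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()

    loss = fake_out.mean() - sharp_out.mean() + gp
    loss.backward()  # no retain_graph: the graph is rebuilt on every update
    discriminator_optimizer.step()
    discriminator_loss += loss.item()
```

Because every update builds its own graph, `retain_graph=True` is no longer needed, and the generator receives no gradients from the discriminator step. The generator update would then use its own, separate forward pass through the discriminator.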