Pythorch训练时,记忆不断累积

2024-06-09 03:48:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在用Pythorch训练一个深度学习模型。由于未知的原因,内存不断积累,导致会话在30个时代以下被杀死和不适合。在

这里有一些想法:

  1. 不知道是不是由matplotlib引起的,所以我添加了plt.close('all');没用

  2. 添加了gc.collect();不起作用

  3. 不知道是不是由cv2.imwrite()引起的,但不知道如何检查。有什么建议吗?

  4. Pythorch问题?

  5. 其他。。。在

    model.train()
    for epo in range(epoch):
        for i, data in enumerate(trainloader, 0):
            inputs = data
            inputs = Variable(inputs)
            optimizer.zero_grad()
    
            top = model.upward(inputs + white(inputs))
            outputs = model.downward(top, shortcut = True)
    
    
            loss = criterion(inputs, outputs)
            loss.backward()
            optimizer.step()
    
            # Print generated pictures every 100 iters
            if i % 100 == 0:
                inn = inputs[0].view(128, 128).detach().numpy() * 255
                cv2.imwrite("/home/tk/Documents/recover/" + str(epo) + "_" + str(i) + ".png", inn)
    
                out = outputs[0].view(128, 128).detach().numpy() * 255
                cv2.imwrite("/home/tk/Documents/recover/" + str(epo) + "_" + str(i) + "_re.png", out)
    
            # Print loss every 50 iters
            if i % 50 == 0:
                print ('[%d, %5d] loss: %.3f' % (epo, i, loss.item()))
    
        gc.collect()
        plt.close("all")
    

===========================================================================

20181222更新

数据集和DalaLoader

^{pr2}$

这个模型真的很复杂很长…再加上以前没有发生过内存积累的问题(使用同一个模型),所以我不在这里发表。。。在


Tags: 内存模型closemodelpltalloutputscv2