RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

问题

在用pytorch生成对抗网络的时候,出现错误Runtime Error: one of the variables needed for gradient computation has been modified by an inplace operation,特记录排坑记录。

环境配置

windows10 2004
python 3.7.4
pytorch 1.7.0 + cpu

解决过程

  • 尝试一

这段错误代码看上去不难理解,意思为:计算梯度所需的某变量已被一就地操作修改。什么是就地操作呢,举个例子如x += 1就是典型的就地操作,可将其改为y = x + 1。但很遗憾,这样并没有解决我的问题,这种方法的介绍如下。
在网上搜了很多相关博客,大多原因如下:

由于0.4.0把Varible和Tensor融合为一个Tensor,inplace操作,之前对Varible能用,但现在对Tensor,就会出错了。

所以解决方案很简单:将所有inplace操作转换为非inplace操作。如将x += 1换为y = x + 1
仍然有一个问题,即如何找到inplace操作,这里提供一个小trick:分阶段调用y.backward(),若报错,则说明这之前有问题;反之则说明错误在该行之后。

  • 尝试二

在我的代码里根本就没有找到任何inplace操作,因此上面这种方法行不通。自己盯着代码,debug,啥也看不出来,好久......
忽然有了新idea。我的训练阶段的代码如下:

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        optimizerD.step()

        # update the generator
        netG.zero_grad()
        # !!!问题出错行
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()        
        optimizerG.step()

判别器loss的backward是正常的,生成器loss的backward有问题。观察到g_loss由两项组成,所以很自然的想法就是删掉其中一项看是否正常。结果为:只保留第一项程序正常运行;g_loss中包含第二项程序就出错。
因此去看了adversarialLoss的代码:

class AdversarialLoss(nn.Module):
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.bec_loss = nn.BCELoss()

    def forward(self, logits_fake):
        # Adversarial Loss
        # !!! 问题在这,logits_fake加上detach后就可以正常运行
        adversarial_loss = self.bec_loss(logits_fake, torch.ones_like(logits_fake))
        return 0.001 * adversarial_loss

看不出来任何问题,只能挨个试。这里只有两个变量:logits_faketorch.ones_like(logits_fake)。后者为常量,所以试着固定logits_fake,不让其参与训练,程序竟能运行了!

class AdversarialLoss(nn.Module):
    def __init__(self):
        super(AdversarialLoss, self).__init__()
        self.bec_loss = nn.BCELoss()

    def forward(self, logits_fake):
        # Adversarial Loss
        # !!! 问题在这,logits_fake加上detach后就可以正常运行
        adversarial_loss = self.bec_loss(logits_fake.detach(), torch.ones_like(logits_fake))
        return 0.001 * adversarial_loss

由此知道了被修改的变量是logits_fake。尽管程序可以运行了,但这样做不一定合理。类AdversarialLoss中没有对logits_fake进行修改,所以返回刚才的训练程序中。

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        # 这里进行的更新操作
        optimizerD.step()

        # update the generator
        netG.zero_grad()
        # !!!问题出错行
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()        
        optimizerG.step()

注意到Discriminator在出错行之前进行了更新操作,因此真相呼之欲出————optimizerD.step()logits_fake进行了修改。直接将其挪到倒数第二行即可,修改后代码为:

for epoch in range(1, epochs + 1):
    for idx, (lr, hr) in enumerate(traindata_loader):
        lrs = lr.to(device)
        hrs = hr.to(device)

        # update the discriminator
        netD.zero_grad()
        logits_fake = netD(netG(lrs).detach())
        logits_real = netD(hrs)
        # Label smoothing
        real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
        fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
        d_loss = bce(logits_real, real) + bce(logits_fake, fake)
        d_loss.backward(retain_graph=True)
        

        # update the generator
        netG.zero_grad()
        g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
        g_loss.backward()   
        optimizerD.step()     
        optimizerG.step()

程序终于正常运行了,耶( •̀ ω •́ )y!

总结

原因:在计算生成器网络梯度之前先对判别器进行更新,修改了某些值,导致Generator网络的梯度计算失败。
解决方法:将Discriminator的更新步骤放到Generator的梯度计算步骤后面。

posted @ 2020-11-03 22:39  Js2Hou  阅读(30058)  评论(1编辑  收藏  举报