pytorch学习笔记——训练时显存逐渐增加,几个epoch后out-of-memory
问题起因:笔者想把别人的torch的代码复制到笔者的代码框架下,从而引起的显存爆炸问题
该bug在困扰了笔者三天的情况下,和学长一同解决了该bug,故在此记录这次艰辛的debug之路。
尝试思路1:检查是否存在保留loss的情况下是否使用了 item() 取值,经检查,并没有
尝试思路2:按照网上的说法,添加两行下面的代码:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
实测发现并没有用。
尝试思路3:及时删除临时变量和清空显存的cache,例如每次训练一个batch就清除模型的输入输出。
del inputs,loss gc.collect() torch.cuda.empty_cache()
这样确实使得模型能够多训练几个epoch,但依旧没有解决显存持续增长的问题,而且由于频繁使用torch.cuda.empty_cache(),导致模型一个epoch的训练时长翻了3倍多。
尝试思路4:重新核对原模型代码,打印模型中所有parameters和register_buffer的require_grad,终于发现是因为模型中的某个register_buffer在训练过程中,它的require_grad本应该为False,然而迁移到我代码上的实际训练过程中变成了True,而这个buffer的占用数据空间也不大,可能是因为变为True之后,导致在显存中一直被保留,从而最终导致显存溢出。再将那个buffer在forward函数里的操作放在torch.no_grad()上下文中,问题解决!
总结:如果训练时显存占用持续增加,需要谨慎的检查forward函数中的操作,尤其是在编写复杂代码的时候,更需要更细致的检查!