pytorch学习笔记——训练时显存逐渐增加,几个epoch后out-of-memory

问题起因:笔者想把别人的torch的代码复制到笔者的代码框架下,从而引起的显存爆炸问题

该bug在困扰了笔者三天的情况下,和学长一同解决了该bug,故在此记录这次艰辛的debug之路。

尝试思路1:检查是否存在保留loss的情况下是否使用了 item() 取值,经检查,并没有

尝试思路2:按照网上的说法,添加两行下面的代码:

torch.backends.cudnn.enabled = True

torch.backends.cudnn.benchmark = True

实测发现并没有用。

尝试思路3:及时删除临时变量和清空显存的cache,例如每次训练一个batch就清除模型的输入输出。

del inputs,loss
gc.collect()
torch.cuda.empty_cache()

这样确实使得模型能够多训练几个epoch,但依旧没有解决显存持续增长的问题,而且由于频繁使用torch.cuda.empty_cache(),导致模型一个epoch的训练时长翻了3倍多

尝试思路4:重新核对原模型代码,打印模型中所有parameters和register_buffer的require_grad,终于发现是因为模型中的某个register_buffer在训练过程中,它的require_grad本应该为False,然而迁移到我代码上的实际训练过程中变成了True,而这个buffer的占用数据空间也不大,可能是因为变为True之后,导致在显存中一直被保留,从而最终导致显存溢出。再将那个buffer在forward函数里的操作放在torch.no_grad()上下文中,问题解决!

 

 

 

总结:如果训练时显存占用持续增加,需要谨慎的检查forward函数中的操作,尤其是在编写复杂代码的时候,更需要更细致的检查

 

posted @ 2022-03-30 23:21  ISGuXing  阅读(8671)  评论(0编辑  收藏  举报