pytorch使用总结

loss的获取

在看别人代码的时候发现都是

loss=net.loss
train_loss+=loss.data[0]#train_loss用于累加梯度

在想为什么不直接使用loss呢,因为pytorch使用Variable跟踪变量(4.0后合并为Tensor),也就是直接使用loss,那么pytorch认为其还在参与运算,其在一个batch后依旧存在于网络中而不是释放掉,所以资源占用会越来越大。

最新版本建议使用

loss.detach()

 

posted @ 2018-11-01 21:45  Luke_Ye  阅读(575)  评论(0编辑  收藏  举报