pytorch使用总结
loss的获取
在看别人代码的时候发现都是
loss=net.loss train_loss+=loss.data[0]#train_loss用于累加梯度
在想为什么不直接使用loss呢,因为pytorch使用Variable跟踪变量(4.0后合并为Tensor),也就是直接使用loss,那么pytorch认为其还在参与运算,其在一个batch后依旧存在于网络中而不是释放掉,所以资源占用会越来越大。
最新版本建议使用
loss.detach()
在看别人代码的时候发现都是
loss=net.loss train_loss+=loss.data[0]#train_loss用于累加梯度
在想为什么不直接使用loss呢,因为pytorch使用Variable跟踪变量(4.0后合并为Tensor),也就是直接使用loss,那么pytorch认为其还在参与运算,其在一个batch后依旧存在于网络中而不是释放掉,所以资源占用会越来越大。
最新版本建议使用
loss.detach()