pytorch调用模型输出变量的一个坑-RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)
报错现象
pytorch中在前向传播的outputs = model(input_data)中的outputs,是不能随便复制调用的,例如
无论是直接赋值、还是使用深拷贝,还是使用torch的clone,都会报这个错
_outputs = outputs
或者
_outputs = copy.deepcopy(outputs)
或者
_outputs = outputs.clone()
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
解决方法
detach() 用于从计算图中分离出张量,这样在进行替换操作时不会影响梯度计算。
_outputs = outputs.detach()
这时候就正常了
本文来自博客园,作者:JaxonYe,转载请注明原文链接:https://www.cnblogs.com/yechangxin/articles/18237595
侵权必究