pytorch调用模型输出变量的一个坑-RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)

报错现象

pytorch中在前向传播的outputs = model(input_data)中的outputs,是不能随便复制调用的,例如
无论是直接赋值、还是使用深拷贝,还是使用torch的clone,都会报这个错

_outputs = outputs
或者
_outputs = copy.deepcopy(outputs)
或者
_outputs = outputs.clone()

image
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

解决方法

detach() 用于从计算图中分离出张量,这样在进行替换操作时不会影响梯度计算。

_outputs = outputs.detach()

这时候就正常了

posted @ 2024-06-07 17:27  JaxonYe  阅读(322)  评论(0编辑  收藏  举报