pytorch 获取非叶子节点的grad
参考url: https://mathpretty.com/12509.html
在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:
- retain_grad()
- hook
retain_grad()
retain_grad()显式地保存非叶节点的梯度, 当然代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)
def forwrad(x, y, w1, w2): # 其中 x,y 为输入数据,w为该函数所需要的参数 z_1 = torch.mm(w1, x) z_1.retain_grad() y_1 = torch.sigmoid(z_1) y_1.retain_grad() z_2 = torch.mm(w2, y_1) z_2.retain_grad() y_2 = torch.sigmoid(z_2) y_2.retain_grad() loss = 1/2*(((y_2 - y)**2).sum()) loss.retain_grad() return loss, z_1, y_1, z_2, y_2 # 测试代码 x = torch.tensor([[1.0]]) y = torch.tensor([[1.0], [0.0]]) w1 = torch.tensor([[1.0], [2.0]], requires_grad=True) w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True) # w2 = torch.tensor([[3.0, 1.0], [1.0, 6.0]], requires_grad=True) # 正向 loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2) # 反向 loss.backward() # 反向传播,计算梯度
print(loss.grad)
print(y_2.grad)
print(z_2.grad)
hook的使用
使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.
# 我们可以定义一个hook来保存中间的变量 grads = {} # 存储节点名称与节点的grad def save_grad(name): def hook(grad): grads[name] = grad return hook def forwrad(x, y, w1, w2): # 其中 x,y 为输入数据,w为该函数所需要的参数 z_1 = torch.mm(w1, x) y_1 = torch.sigmoid(z_1) z_2 = torch.mm(w2, y_1) y_2 = torch.sigmoid(z_2) loss = 1/2*(((y_2 - y)**2).sum()) return loss, z_1, y_1, z_2, y_2 # 测试代码 x = torch.tensor([[1.0]]) y = torch.tensor([[1.0], [0.0]]) w1 = torch.tensor([[1.0], [2.0]], requires_grad=True) w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True) # 正向传播 loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2) # hook中间节点 z_1.register_hook(save_grad('z_1')) y_1.register_hook(save_grad('y_1')) z_2.register_hook(save_grad('z_2')) y_2.register_hook(save_grad('y_2')) loss.register_hook(save_grad('loss')) # 反向传播 loss.backward() print(grads['z_1']) print(grads['y_1']) print(grads['z_2']) print(grads['y_2']) print(grads['loss'])