pytorch 获取非叶子节点的grad

参考url: https://mathpretty.com/12509.html

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:

  • retain_grad()
  • hook

retain_grad()

retain_grad()显式地保存非叶节点的梯度, 当然代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)

def forwrad(x, y, w1, w2):
    # 其中 x,y 为输入数据,w为该函数所需要的参数
    z_1 = torch.mm(w1, x)
    z_1.retain_grad()
    y_1 = torch.sigmoid(z_1)
    y_1.retain_grad()
    z_2 = torch.mm(w2, y_1)
    z_2.retain_grad()
    y_2 = torch.sigmoid(z_2)
    y_2.retain_grad()
    loss = 1/2*(((y_2 - y)**2).sum())
    loss.retain_grad()
    return loss, z_1, y_1, z_2, y_2

# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# w2 = torch.tensor([[3.0, 1.0], [1.0, 6.0]], requires_grad=True)
# 正向
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# 反向
loss.backward()  # 反向传播,计算梯度
print(loss.grad)
print(y_2.grad)
print(z_2.grad)

hook的使用

使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.

# 我们可以定义一个hook来保存中间的变量
grads = {} # 存储节点名称与节点的grad
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

def forwrad(x, y, w1, w2):
    # 其中 x,y 为输入数据,w为该函数所需要的参数
    z_1 = torch.mm(w1, x)
    y_1 = torch.sigmoid(z_1)
    z_2 = torch.mm(w2, y_1)
    y_2 = torch.sigmoid(z_2)
    loss = 1/2*(((y_2 - y)**2).sum())
    return loss, z_1, y_1, z_2, y_2

# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向传播
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# hook中间节点
z_1.register_hook(save_grad('z_1'))
y_1.register_hook(save_grad('y_1'))
z_2.register_hook(save_grad('z_2'))
y_2.register_hook(save_grad('y_2'))
loss.register_hook(save_grad('loss'))
# 反向传播
loss.backward()
print(grads['z_1'])
print(grads['y_1'])
print(grads['z_2'])
print(grads['y_2'])
print(grads['loss'])

 

posted @ 2022-04-14 20:49  dangxusheng  阅读(701)  评论(0编辑  收藏  举报