PyTorch 的 register_hook 和 register_backward_hook：介绍与实验

class Classifier(nn.Module):
    """Small conv net with one residual (skip) connection and a linear head.

    Data flow: layer1 (conv in_ch->3, ReLU) -> layer2 (conv 3->6, ReLU,
    conv 6->3, ReLU) plus the skip from layer1's output -> flatten ->
    Linear(3 * in_size * in_size, 1).

    Args:
        in_size: spatial height/width of the input. Every conv uses kernel 3,
            stride 1, padding 1, so the spatial size is preserved and the
            Linear layer expects ``3 * in_size * in_size`` features. This
            assumes a square ``in_size x in_size`` input.
        in_ch: number of input channels.
    """

    def __init__(self, in_size, in_ch):
        super(Classifier, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_ch, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(3, 6, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(6, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.fc = nn.Linear(3 * in_size * in_size, 1)

    def forward(self, x):
        """Map an input of shape (batch, in_ch, in_size, in_size) to (batch, 1)."""
        x = self.layer1(x)
        identity = x
        x = self.layer2(x)
        # Out-of-place on purpose: `x += identity` would overwrite the output
        # of layer2's final ReLU. ReLU's backward reuses that saved output, so
        # backward() would fail with an in-place modification error.
        x = x + identity
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def print_grad(grad):
    """Tensor hook that dumps the incoming gradient's shape and values.

    Meant to be passed to ``Tensor.register_hook``. It returns ``None``, so the
    gradient is passed through unchanged.
    """
    header = '========= register_hook output:======== '
    for item in (header, grad.size(), grad):
        print(item)


def grad_hook(md, grad_in, grad_out):
    """Module backward hook: print the gradient w.r.t. the module's output.

    Intended for ``nn.Module.register_backward_hook``.

    Args:
        md: the module the hook is attached to (not used here).
        grad_in: tuple of gradients w.r.t. the inputs of the module's last
            autograd op; not printed. NOTE(review): for ``nn.Linear`` this is
            commonly ``(grad_bias, grad_x, grad_weight^T)``, but the order
            depends on the op PyTorch records. Verify on your version.
        grad_out: tuple of gradients w.r.t. the module's output; only the
            first entry is printed. For ``nn.Linear`` with batch size 1 this
            should equal the bias gradient (bias grad = grad_out summed over
            the batch).

    Returns:
        None, so the gradients are left unchanged.
    """
    print('========= register_backward_hook output:======== ')
    print(grad_out[0].size())
    print(grad_out[0])


torch.random.manual_seed(1000)

if __name__ == '__main__':
    in_size, in_ch = 4, 1
    x = torch.randn(1, 1, 4, 4)
    model = Classifier(in_size, in_ch)
    y_hat = model(x)
    y_gt = torch.tensor([[1.5]])
    crt = nn.MSELoss()
    print(y_hat)
    print('=======================')

    # Replay Classifier.forward submodule by submodule, so that a tensor hook
    # and a module backward hook can be attached at every step.
    identity = None
    for idx, (_, md) in enumerate(model.named_children()):
        md.register_backward_hook(grad_hook)
        if isinstance(md, nn.Linear):
            # Out-of-place add: an in-place `+=` would overwrite the ReLU
            # output that ReLU's backward needs, making loss.backward() raise.
            x = x + identity
            x = torch.flatten(x, 1)
        x = md(x)
        x.register_hook(print_grad)
        if idx == 0:
            identity = x  # skip connection taken from layer1's output
    loss = crt(x, y_gt)
    loss.backward()
    print(x)

 

posted @ 2020-03-25 22:34  dangxusheng  阅读(3815)  评论(0)  编辑  收藏  举报