# PyTorch 的 register_hook 和 register_backward_hook 的介绍和实验
# (Introduction to and experiments with PyTorch's register_hook and register_backward_hook)
import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Tiny conv net with one residual connection, used to observe hook output.

    Structure:
        layer1: Conv(in_ch -> 3, k=3, s=1, p=1) + ReLU
        layer2: Conv(3 -> 6) + ReLU + Conv(6 -> 3) + ReLU   (same k/s/p)
        The output of layer1 is added to the output of layer2 (residual),
        flattened, and fed to a single-output linear layer ``fc``.

    Args:
        in_size: height/width of the input feature map. The 3x3 convs use
            stride 1 and padding 1, so the spatial size is preserved, and
            ``fc`` expects ``3 * in_size * in_size`` features.
        in_ch: number of input channels.
    """

    def __init__(self, in_size, in_ch):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_ch, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(3, 6, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(6, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.fc = nn.Linear(3 * in_size * in_size, 1)

    def forward(self, x):
        x = self.layer1(x)
        identity = x
        x = self.layer2(x)
        # Out-of-place add: layer2 ends in a ReLU whose backward reuses its
        # output, so an in-place ``x += identity`` would make backward() fail.
        x = x + identity
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def print_grad(grad):
    """Tensor hook for ``Tensor.register_hook``: print the incoming gradient.

    Returns None, so the gradient propagated by autograd is left unchanged.
    """
    print('========= register_hook output:======== ')
    print(grad.size())
    print(grad)


def grad_hook(md, grad_in, grad_out):
    """Module hook for ``Module.register_backward_hook``: print ``grad_out[0]``.

    Args:
        md: the module the hook is registered on.
        grad_in: gradients w.r.t. the inputs of the *last autograd op* run in
            the module's forward, not necessarily the module's own inputs.
            For the ``nn.Sequential`` blocks in this script that op is the
            final ReLU. For ``nn.Linear`` it is typically addmm, giving
            gradients for (bias, input, weight.t()); NOTE(review): confirm the
            ordering against the installed PyTorch version. This limitation is
            why ``register_backward_hook`` is deprecated in favour of
            ``register_full_backward_hook``.
        grad_out: gradients of the loss w.r.t. the module's output(s). For
            ``nn.Linear`` with a batch of one, ``grad_out[0]`` equals the bias
            gradient (the bias gradient is grad_out summed over the batch).
    """
    print('========= register_backward_hook output:======== ')
    print(grad_out[0].size())
    print(grad_out[0])


torch.random.manual_seed(1000)

if __name__ == '__main__':
    in_size, in_ch = 4, 1
    x = torch.randn(1, 1, 4, 4)
    model = Classifier(in_size, in_ch)
    y_hat = model(x)
    y_gt = torch.tensor([[1.5]])
    crt = nn.MSELoss()
    print(y_hat)
    print('=======================')

    # Re-run the forward pass by hand, one child module at a time, so that a
    # tensor hook can be attached to every intermediate output as well.
    identity = None
    for idx, (_, md) in enumerate(model.named_children()):
        md.register_backward_hook(grad_hook)
        if isinstance(md, nn.Linear):
            # Mirror Classifier.forward: residual add (out-of-place), flatten.
            x = x + identity
            x = torch.flatten(x, 1)
        x = md(x)
        x.register_hook(print_grad)
        if idx == 0:
            identity = x  # layer1 output, reused as the residual branch
    loss = crt(x, y_gt)
    loss.backward()
    print(x)