import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked 3x3 convolutions, each followed by ReLU.

    Used below to demonstrate how a forward pre-hook can inspect and
    replace the input of a module before ``forward`` runs.
    """

    def __init__(self):
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size)
        self.cov1 = nn.Conv2d(1, 3, 3)
        self.cov2 = nn.Conv2d(3, 2, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``cov1`` then ``cov2``, each followed by ReLU.

        ``nn.Module.__call__`` dispatches to this method, so it must be
        named ``forward``; otherwise ``model(x)`` raises
        ``NotImplementedError`` and registered hooks are never exercised.

        Args:
            x: Input tensor accepted by ``cov1`` (one input channel).

        Returns:
            The activated output of ``cov2`` (two output channels).
        """
        print("model backword fisrt", x.shape)
        x = F.relu(self.cov1(x))
        x = F.relu(self.cov2(x))
        print("model backword end", x.shape)
        return x

    def backword(self, x: torch.Tensor) -> torch.Tensor:
        """Backward-compatible alias for :meth:`forward` (original name).

        Calling this directly bypasses ``nn.Module.__call__`` and therefore
        does not trigger any registered hooks.
        """
        return self.forward(x)


def before_hook(model: nn.Module, input: tuple) -> torch.Tensor:
    """Forward pre-hook that logs the incoming shape and swaps the input.

    Args:
        model: The module whose forward pass is about to run.
        input: Tuple of positional arguments passed to the module.

    Returns:
        A zero tensor of shape ``(1, 1, 7, 7)``. Returning a value from a
        forward pre-hook makes PyTorch use it in place of the original
        input.
    """
    print("brefore hook", model, " input ", input[0].shape)
    return torch.zeros(1, 1, 7, 7)


model = Model()
hook = model.register_forward_pre_hook(before_hook)
input = torch.zeros(1, 1, 5, 5)
try:
    # The hook replaces the 5x5 input with a 7x7 one before forward runs.
    model(input)
finally:
    # Always detach the hook so it does not affect later calls.
    hook.remove()