PyTorch查看Sequential中tensor维度的方法
需要定义一个回调函数:
def get_features_hook(self, input, output):
print("hook", output.data.cpu().numpy().shape)
然后对需要查看的层注册钩子:
handle = self.model.fc_loc[2].register_forward_hook(get_features_hook)
在查看完后移除钩子:
handle.remove()