PyTorch查看Sequential中tensor维度的方法

需要定义一个回调函数:

def get_features_hook(self, input, output):
    print("hook", output.data.cpu().numpy().shape)

然后对需要查看的层注册钩子:

handle = self.model.fc_loc[2].register_forward_hook(get_features_hook)

在查看完后移除钩子:

handle.remove()
posted @ 2018-08-07 01:55  木易修  阅读(972)  评论(0编辑  收藏  举报