pytorch的四个hook函数
训练神经网络模型有时需要观察模型内部模块的输入输出,或是期望在不修改原始模块结构的情况下调整中间模块的输出,pytorch可以用hook回调函数来实现这一功能。主要使用四个hook注册函数:register_forward_hook、register_forward_pre_hook、register_full_backward_hook、register_full_backward_pre_hook。这四个函数可以被继承nn.Module的任意模块调用,传入hook函数并进行注册,从而在执行该模块的相应阶段调用hook函数实现所需功能。
register_forward_hook(self, hook, *, prepend, with_kwargs)
为模块注册一个在该模块前向传播之后执行的回调函数。
hook(module, args, output):需执行的回调函数对象,module为当前模块引用,args为当前模块前向传播输入,output为当前模块前向传播输出。可以返回修改后的output来修改该模块前向传播输出。
prepend:将该hook函数放在回调函数列表最前面,从而最先执行,否则放在队列最后。
with_kwargs:hook函数是否传入关键字参数,如果为True,则hook额外增加关键字参数,变为 hook(module, args, kwargs, output)。注意!如果with_kwargs=False,模块传入的关键字参数将不会被捕获,坑了我一个下午。
register_forward_hook注册函数本身返回一个handle句柄,可执行handle.remove()将注册的该hook函数移除。
register_forward_pre_hook(self, hook, *, prepend, with_kwargs)
为模块注册一个在该模块前向传播之前执行的回调函数。
hook(module, args):args为该模块前向传播输入。可以返回修改后的args来修改该模块前向传播输入。
其它参数、特性与前面一致。
register_full_backward_hook(self, hook, prepend)
为模块注册一个在该模块反向传播之后执行的回调函数。
hook(module, grad_input, grad_output):grad_input与grad_output分别为该模块前向传播输入和输出的梯度。可以返回修改后的grad_input来修改该模块前向传播输入的梯度。
register_full_backward_pre_hook(self, hook, prepend)
为模块注册一个在该模块反向传播之前执行的回调函数。
hook(module, grad_output):grad_output为该模块前向传播输出的梯度。可以返回修改后的grad_output来修改这一梯度。