『PyTorch』第十六弹_hook技术
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。
钩子函数包括Variable的钩子和nn.Module钩子,用法相似。
一、register_hook
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | import torch from torch.autograd import Variable grad_list = [] def print_grad(grad): grad_list.append(grad) x = Variable(torch.randn( 2 , 1 ), requires_grad = True ) y = x + 2 z = torch.mean(torch. pow (y, 2 )) lr = 1e - 3 y.register_hook(print_grad) z.backward() x.data - = lr * x.grad.data print (grad_list) |
二、register_forward_hook
& register_backward_hook
这两个函数的功能类似于variable函数的register_hook
,可在module前向传播或反向传播时注册钩子。
每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None
,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None
。
钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。下面给出实现的伪代码。
model = VGG()
features = t.Tensor()
def hook(module, input, output):
'''把这层的输出拷贝到features中'''
features.copy_(output.data)
handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()
测试LeNet网络
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | import torch as t import torch.nn as nn import torch.nn.functional as F class LeNet(nn.Module): def __init__( self ): super (LeNet, self ).__init__() self .conv1 = nn.Conv2d( 1 , 6 , 5 ) self .conv2 = nn.Conv2d( 6 , 16 , 5 ) self .fc1 = nn.Linear( 16 * 5 * 5 , 120 ) self .fc2 = nn.Linear( 120 , 84 ) self .fc3 = nn.Linear( 84 , 10 ) def forward( self ,x): x = F.max_pool2d(F.relu( self .conv1(x)),( 2 , 2 )) x = F.max_pool2d(F.relu( self .conv2(x)), 2 ) x = x.view(x.size()[ 0 ], - 1 ) x = F.relu( self .fc1(x)) x = F.relu( self .fc2(x)) x = self .fc3(x) return x |
先模拟一下单次的向前传播,
1 2 3 | net = LeNet() img = t.autograd.Variable((t.arange( 32 * 32 * 1 ).view( 1 , 1 , 32 , 32 ))) net(img) |
Variable containing: Columns 0 to 7 27.6373 -13.4590 23.0988 -16.4491 -8.8454 -15.6934 -4.8512 1.3490 Columns 8 to 9 3.7801 -15.9396 [torch.FloatTensor of size 1x10]
仿照上面示意,进行钩子注册,获取第一卷积层输出结果,
1 2 3 4 5 6 7 8 | def hook(module, inputdata, output): '''把这层的输出拷贝到features中''' print (output.data) handle = net.conv2.register_forward_hook(hook) net(img) # 用完hook后删除 handle.remove() |
……
……
[torch.FloatTensor of size 1x16x10x10]
看看hook能识别什么
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | import torch from torch import nn import torch.functional as F from torch.autograd import Variable def for_hook(module, input , output): print (module) for val in input : print ( "input val:" ,val) for out_val in output: print ( "output val:" , out_val) class Model(nn.Module): def __init__( self ): super (Model, self ).__init__() def forward( self , x): return x + 1 model = Model() x = Variable(torch.FloatTensor([ 1 ]), requires_grad = True ) handle = model.register_forward_hook(for_hook) print (model(x)) handle.remove() |
可见对于目标层,其输入输出都可以获取到,
Model( ) input val: Variable containing: 1 [torch.FloatTensor of size 1] output val: Variable containing: 2 [torch.FloatTensor of size 1] Variable containing: 2 [torch.FloatTensor of size 1]
标签:
PyTorch
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 地球OL攻略 —— 某应届生求职总结
· 提示词工程——AI应用必不可少的技术
· Open-Sora 2.0 重磅开源!
· 字符编码:从基础到乱码解决