pytorch 的register_hook和register_backward_hook的介绍和实验
class Classifier(nn.Module): def __init__(self, in_size, in_ch): super(Classifier, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_ch, 3, 3, 1, 1), nn.ReLU(), ) self.layer2 = nn.Sequential( nn.Conv2d(3, 6, 3, 1, 1), nn.ReLU(), nn.Conv2d(6, 3, 3, 1, 1), nn.ReLU(), ) self.fc = nn.Linear(3 * in_size * in_size, 1) def forward(self, x): x = self.layer1(x) identity = x x = self.layer2(x) x += identity x = torch.flatten(x, 1) x = self.fc(x) return x def print_grad(grad): print('========= register_hook output:======== ') print(grad.size()) print(grad) def grad_hook(md, grad_in, grad_out): print('========= register_backward_hook output:======== ')
# grad_in 包含: grad_bias, grad_x, grad_w 三者的梯度: (delta_bias, delta_x, delta_w)
# grad_out 是md整体的梯度,也等于grad_bias
print(grad_out[0].size()) print(grad_out[0]) torch.random.manual_seed(1000) if __name__ == '__main__': in_size, in_ch = 4, 1 x = torch.randn(1, 1, 4, 4) model = Classifier(in_size, in_ch) y_hat = model(x) y_gt = torch.Tensor([[1.5]]) crt = nn.MSELoss() print(y_hat) print('=======================') identity = [] for idx, (name, md) in enumerate(model._modules.items()): md.register_backward_hook(grad_hook) if isinstance(md, nn.Linear): x += identity[0] x = torch.flatten(x, 1) x = md(x) x.register_hook(print_grad) if idx == 0: identity.append(x) loss = crt(x, y_gt) loss.backward() print(x)
分类:
深度学习
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人