pytorch 的register_hook和register_backward_hook的介绍和实验

复制代码
class Classifier(nn.Module):
    def __init__(self, in_size, in_ch):
        super(Classifier, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_ch, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(3, 6, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(6, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.fc = nn.Linear(3 * in_size * in_size, 1)

    def forward(self, x):
        x = self.layer1(x)
        identity = x
        x = self.layer2(x)
        x += identity
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def print_grad(grad):
    print('========= register_hook output:======== ')
    print(grad.size())
    print(grad)


def grad_hook(md, grad_in, grad_out):
    print('========= register_backward_hook output:======== ')
   # grad_in 包含: grad_bias, grad_x, grad_w 三者的梯度: (delta_bias, delta_x, delta_w)
# grad_out 是md整体的梯度,也等于grad_bias
print(grad_out[0].size()) print(grad_out[0]) torch.random.manual_seed(1000) if __name__ == '__main__': in_size, in_ch = 4, 1 x = torch.randn(1, 1, 4, 4) model = Classifier(in_size, in_ch) y_hat = model(x) y_gt = torch.Tensor([[1.5]]) crt = nn.MSELoss() print(y_hat) print('=======================') identity = [] for idx, (name, md) in enumerate(model._modules.items()): md.register_backward_hook(grad_hook) if isinstance(md, nn.Linear): x += identity[0] x = torch.flatten(x, 1) x = md(x) x.register_hook(print_grad) if idx == 0: identity.append(x) loss = crt(x, y_gt) loss.backward() print(x)
复制代码

 

posted @   dangxusheng  阅读(3899)  评论(0编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示