ruijiege

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class Model(nn.Module):
    """Small two-layer convolutional network used to demonstrate module hooks.

    Both layers are 3x3 convolutions built with the default stride and
    padding, so each one shrinks the spatial size by 2 in both dimensions
    (e.g. 7x7 -> 5x5 -> 3x3). Channels go 1 -> 3 -> 2.
    """

    def __init__(self):
        super().__init__()
        self.cov1 = nn.Conv2d(1,3,3)
        self.cov2 = nn.Conv2d(3,2,3)

    def forward(self, x):
        """Apply conv -> ReLU twice and return the resulting tensor.

        The method must be named ``forward``: ``nn.Module.__call__`` (and
        therefore every forward / forward-pre hook) dispatches to
        ``self.forward``. Under the original name ``backword`` calling
        ``model(x)`` raised ``NotImplementedError``.

        Args:
            x: input tensor with one channel, shaped (N, 1, H, W) with
               H, W >= 5 so that both 3x3 convolutions fit.

        Returns:
            Tensor shaped (N, 2, H - 4, W - 4).
        """
        print("model backword fisrt",x.shape)
        x = F.relu(self.cov1(x))
        x = F.relu(self.cov2(x))
        print("model backword end",x.shape)
        return x

    # Backward-compatible alias for code that still calls the old,
    # misspelled method name directly.
    backword = forward
    
def before_hook(model,input):
    """Log the incoming input shape and return a fixed all-zero tensor.

    Written to be registered as a forward pre-hook: ``input`` is indexed as
    a sequence and only its first element's shape is reported.

    Args:
        model: the module the hook is attached to (only printed).
        input: sequence of positional inputs; ``input[0]`` must expose
               ``.shape``.

    Returns:
        A new ``torch.zeros(1, 1, 7, 7)`` tensor, regardless of the input.
    """
    incoming_shape = input[0].shape
    print("brefore hook",model," input ",incoming_shape)
    replacement = torch.zeros(1, 1, 7, 7)
    return replacement

# Demo: attach a forward pre-hook, run the model once, then detach the hook.
model = Model()
hook = model.register_forward_pre_hook(before_hook)
try:
    # A 5x5 input is passed in; the registered pre-hook is invoked before the
    # module's forward computation and may substitute the input it receives.
    input = torch.zeros(1,1,5,5)
    model(input)
finally:
    # Always detach the hook, even if the forward call raises, so the model
    # is not left with a stale hook registered.
    hook.remove()

 

posted on 2022-10-29 08:24  哦哟这个怎么搞  阅读(52)  评论(0)  编辑  收藏  举报