python中`__call__`让类像函数一样被调用

在训练一个网络过程中,有下面代码

'''4.训练网络'''
print('开始训练')
for epoch in range(3):
    runing_loss = 0.0

    for i, data in enumerate(trainloader, 0):
        inputs, label = data  # 1.数据加载
        if device == 'gpu':
            inputs = inputs.cuda()
            label = label.cuda()
        optimizer.zero_grad()  # 2.初始化梯度
        output = Net(inputs)  # 3.计算前馈
        loss = criterion(output, label)  # 4.计算损失
        loss.backward()  # 5.计算梯度
        optimizer.step()  # 6.更新权值

        runing_loss += loss.item()
        if i % 20 == 19:
            print('epoch:', epoch, 'loss', runing_loss / 20)
            runing_loss = 0.0

print('训练完成')

发现在第三步计算前馈过程中output = Net(inputs)没有调用forward()方法,这是为什么?

以下参考Pytorch 模型中nn.Model 中的forward() 前向传播不调用 解释

在pytorch 中没有调用模型的forward()前向传播,只实列化后把参数传入。

class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        # ......

    def forward(self, x):
        # ......
        return x

data = .....  #输入数据
# 实例化一个对象
module = Module()
#  前向传播 直接把输入传入实列化
module(data)  
#没有使用module.forward(data)

实际上module(data) 等价于module.forward(data)

等价的原因是因为 python calss 中的__call__ 可以让类像函数一样调用

当执行model(x)的时候,底层自动调用forward方法计算结果

class A():
    def __call__(self):
        print('i can be called like a function')
 
a = A()
a()
>>>i can be called like a function

在__call__ 里可调用其它的函数

class A():
    def __call__(self, param):

        print('我在__call__中,传入参数',param)

        res = self.forward(param) # <<<<<<<<<<<<<<<< 注意这里
        return res

    def forward(self, x):
        print('我在forward函数中,传入参数类型是值为: ',x)
        return x

a = A()
y = a('i')
#  >>> 我在__call__中,传入参数 i
#  >>>我在forward函数中,传入参数类型是值为:  i

print("传入的参数是:", y)
#  >>>传入的参数是: i

Reference

Pytorch 模型中nn.Model 中的forward() 前向传播不调用 解释

posted @ 2024-03-19 10:59  光辉233  阅读(1)  评论(0编辑  收藏  举报