Pytorch的类(nn.Module的子类)中的forward函数
使用
直接通过类的实例对象就可以向类中的forward函数进行参数的传递(当然也可以通过调用forward函数进行传参)
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
data1 = 1
data2 = 2
module = MyModule()
x1 = module(data1) # 不需要显示调用forward函数就可以传递参数
x2 = module.forward(data2)
print(x1)
print(x2)
>> 1
>> 2
解释
nn.Module() 中包含了 __call__
函数;
实现了 __call__
函数的类,其类实例是一个可调用的对象,其可以简化对于类中某些方法的调用(写在__call__
中的方法),模糊了实例对象和类成员函数的区别。使用类实例 module() 时 就相当于 module.__call__(),如果在 __call()__ 中写上函数,就可以直接通过类实例对象传参调用了。
而在 nn.Module() 中的 __call__
函数中调用了 forward() 函数,
...
# 例子 #
def __call__(self, param):
res = self.forward(param)
return res
...
由于继承关系,对于MyModule(nn.Module) 类 同样具备了 __call__
函数的功能,即可以通过类实例module 直接 调用 forward 并传参。
本文作者:jacknie23
本文链接:https://www.cnblogs.com/jack-nie-23/p/16506630.html
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步