pytorch小知识(01):forward方法
我们都知道torch.nn.Module在被继承时要求我们改写两个方法,一个是__init__,一个是forward。前者用于定义层,后者用于定义前向计算的流程。但是当我们在实际使用一个网络时,我们不会使用forward这个方法进行计算,而是进行如下的操作:
可以看到我们直接使用了net来接收实例,然后在进行对标签的计算时,直接使用了y=net(x)。
实际上,这是由于nn.Module自带的__call__魔法方法在起作用。这里,当我们调用net时,__call__会帮我们自动调用forward方法。所以,我们不管是使用net(x),net.forward(x)还是net.__call__(x),得到的结果都是对x进行了一轮计算。
以上,就是今天的pytorch小知识。
点击查看代码
import torch
import torch.nn as nn
from torchsummary import summary
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer1 = nn.Linear(300,200)
self.layer2 = nn.Linear(200,200)
self.out = nn.Linear(200,10)
nn.init.kaiming_normal_(self.layer1.weight)
nn.init.xavier_normal_(self.layer2.weight)
nn.init.kaiming_uniform_(self.out.weight)
def forward(self, x):
x = self.layer1(x)
x = torch.relu(x)
x = self.layer2(x)
x = torch.sigmoid(x)
x = self.out(x)
x = torch.softmax(x, dim=-1)
return x
if __name__ == '__main__':
device = torch.device('cuda:0')
net = Model().to(device)
data = torch.randn(500,300).to(device)
y = net(data)
print('y:',y)
summary(model=net,input_size=(300,),batch_size=500)
print("======查看模型参数w和b======")
for name, parameter in net.named_parameters():
print(name, parameter)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!