PyTorch之Sequential
classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。
#!/usr/bin/env python2 # -*- coding: utf-8 -*- import torch import torch.nn.functional as F #旧方法搭建神经网络 class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) x = self.predict(x) return x net1 = Net(1, 10, 1) # 这是我们用这种方式搭建的 net1 print net1 """ Net ( (hidden): Linear (1 -> 10) (predict): Linear (10 -> 1) ) """
构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。
#!/usr/bin/env python2 # -*- coding: utf-8 -*- import torch import torch.nn.functional as F #新方法搭建神经网络 net2 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) #net2比net1多显示了一个激活函数 print net2 """ Sequential ( (0): Linear (1 -> 10) (1): ReLU () (2): Linear (10 -> 1) ) """