PyTorch 模型构造
记录几种模型构造的方法:
继承Module
类来构造模型
Module
是所有神经网络模块的基类,通过继承它来得到我们需要的模型,通常我们需要重载Module
类的__init__
函数和forward
函数。
实例
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
利用Module
的子类
在Pytorch中实现了继承自Module
的可以方便构造模型的类,有Sequential
、ModuleList
、ModuleDict
等
-
使用
Sequential
当模型的前向计算为简单串联各个层的计算时,
Sequential
类可以通过更加简单的方式定义模型。这正是Sequential
类的目的:它可以接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加Module
的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算。这里实现一个与
Sequential
具有相似功能的MySequential
类class MySequential(nn.Module): from collections import OrderedDict def __init__(self, *args): super(MySequential, self).__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict for key, module in args[0].items(): self.add_module(key, module) # add_module方法会将module添加进self._modules(一个OrderedDict) else: # 传入的是一些Module for idx, module in enumerate(args): self.add_module(str(idx), module) def forward(self, input): # self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成员 for module in self._modules.values(): input = module(input) return input
-
使用
ModuleList
将子模块放在一个列表(
list
)之中
ModuleList
可以像常规的Python list一样执行append()
、extend()
操作,有一些区别在于ModuleList
中的所有模块的参数会被自动地添加到整个网络之中实例
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()]) net.append(nn.Linear(256, 10)) # # 类似List的append操作 print(net[-1]) # 类似List的索引访问 print(net)
虽然
Sequential
和ModuleList
都可以列表化构造网络,但二者存在区别:ModuleList
仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度 匹配),而且没有实现forward
功能(需要自己实现)。Sequential
内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部forward
功能已经实现。ModuleList
的出现可以让网络定义前向传播时更加灵活:class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) def forward(self, x): # ModuleList can act as an iterable, or be indexed using ints for i, l in enumerate(self.linears): x = self.linears[i // 2](x) + l(x) return x
-
使用
ModuleDict
ModuleDict
接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:net = nn.ModuleDict({ 'linear': nn.Linear(784, 256), 'act': nn.ReLU(), }) net['output'] = nn.Linear(256, 10) # 添加 print(net['linear']) # 访问 print(net.output) print(net) # net(torch.zeros(1, 784)) # 会报NotImplementedError
和
ModuleList
一样,使用ModuleDict
时同样需要自己定义forward
参考: