PyTorch 中的 ModuleList 和 Sequential
nn.ModuleList
可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非
是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会注册到整个网络上的,同时 module 的
parameters 也会自动添加到整个网络中。
class net1(nn.Module): def __init__(self): super(net1, self).__init__() self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)]) def forward(self, x): for m in self.linears: x = m(x) return x net = net1() print(net) # net1( # (modules): ModuleList( # (0): Linear(in_features=10, out_features=10, bias=True) # (1): Linear(in_features=10, out_features=10, bias=True) # ) # ) for param in net.parameters(): print(type(param.data), param.size()) # <class 'torch.Tensor'> torch.Size([10, 10]) # <class 'torch.Tensor'> torch.Size([10]) # <class 'torch.Tensor'> torch.Size([10, 10]) # <class 'torch.Tensor'> torch.Size([10])
我们可以看到,这个网络包含两个全连接层,他们的权重 (weithgs) 和偏置 (bias) 都在这个网络之内。接下来我们看看第二个网络,
它使用 Python 自带的 list:
class net2(nn.Module): def __init__(self): super(net2, self).__init__() self.linears = [nn.Linear(10,10) for i in range(2)] def forward(self, x): for m in self.linears: x = m(x) return x net = net2() print(net) # net2() print(list(net.parameters())) # []
显然,使用 Python 的 list 添加的全连接层和它们的 parameters 并没有自动注册到我们的网络中。当然,我们还是可以使用
forward 来计算输出结果。但是如果用 net2 实例化的网络进行训练的时候,因为这些层的 parameters 不在整个网络之中,所以其网络参数
也不会被更新。
好,看到这里,我们大致明白了 nn.ModuleList 是干什么的了:它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络
之中的容器。但是,我们需要注意到,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。
nn.Sequential
nn.Sequential 是一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为
元素的有序字典也可以作为传入参数。
不同于 nn.ModuleList,它已经实现的 forward 函数,而且里面的模块是按照顺序进行排列的,所以我们必须确保前一个模块的输
出大小和下一个模块的输入大小是一致的。
# Example of using Sequential model = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) # Example of using Sequential with OrderedDict model = nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ]))
下面来看一下它的源码。在初始化函数 __init__
中,首先是 if 条件判断,如果传入的参数为 1 个,并且类型为 OrderedDict,通过
字典索引的方式将子模块添加到 self._module
中,否则,通过 for 循环遍历参数,将所有的子模块添加到 self._module
中。
注意:Sequential 模块的初始换函数没有异常处理,所以在写的时候要注意。
由于每一个神经网络模块都继承于 nn.Module,因此都会实现 __call__
与 forward
函数,所以 forward 函数中通过 for 循环依次
调用添加到现有模块中的子模块,最后输出经过所有神经网络层的结果。