torch.nn ------ 剩余容器
torch.nn ------ 剩余容器
一、Sequential
一个序列容器。模块将按照它们在构造函数中传递的顺序添加到其中。或者传入OrderedDict
模块(value是子模块)。该容器的forward()
方法接受任何输入并将其转发到它包含的第一个模块。然后,它将每个后续模块的输出顺序“链接”到输入,最后返回最后一个模块的输出。
Sequential可以手动调用提供的模块序列,它允许将整个容器视为单个模块,这样,对Sequential执行可以理解为顺序执行它存储的每个模块(每个模块都是Sequential的注册子模块)。
顺序列表和ModuleList有什么区别?ModuleList就是它听起来的样子:一个存储模块的列表!另一方面,按顺序排列的层和级联方式连接。
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
这里Sequential序列可以看作一个模块进行执行,这是它的最大特点和优势!
二、ModuleList
构建一个子模型列表,这个列表不能直接执行!如:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
常见的列表方法这里也是使用的,有:
- append: 在末尾追加一个子模型,参数module其参数类型为nn.Module
- extend: 在末尾追加多个子模型
- inset: 在指定索引处插入子模型
三、ModuleDict
构建一个子模型字典,注意这是有序字典。参数modules是一个可迭代对象,如字典,或其他键值对数据。
案例:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
关于字典的方法实现了如下功能:
- clear(): 清空所有子模块
- items(): 迭代返回子模块名和子模块
- keys(): 返回所有子模块名
- pop(key): 删除子模块“key”
- update(modules): 更新子模型
- values(): 返回所有子模块
测试update方法:
>>> net = nn.ModuleDict({"conv": nn.Conv2d(12,64,3), "pool": nn.MaxPool2d(3)})
>>> net
ModuleDict(
(conv): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1))
(pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> net.update({"conv1": nn.Conv2d(64, 128, 3), "pool": nn.MaxPool2d(2)})
ModuleDict(
(conv): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
)
注意新加的子模块默认还是在最后!
四、ParameterList
ParameterList可以像常规Python列表一样被索引,但它包含的参数已正确注册,并且所有模块方法是可使用的。参数parameters是一个可迭代的参数封装,如:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.params = nn.ParameterList(
[nn.Parameter(torch.randn(10, 10)) for i in range(10)]
)
def forward(self, x):
# ParameterList can act as an iterable, or be indexed using ints
for i, p in enumerate(self.params):
x = self.params[i // 2].mm(x) + p.mm(x)
return x
它实现了append和extend方法!
五、ParameterDict
与ModuleDict类似,这也是一个有序的字典。使用案例:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.params = nn.ParameterDict({
'left': nn.Parameter(torch.randn(5, 10)),
'right': nn.Parameter(torch.randn(5, 10))
})
def forward(self, x, choice):
x = self.params[choice].mm(x)
return x
关于参数字典实现的方法有:
- clear(): 清空所有参数项目
- copy(): 返回此ParameterDict的副本
- fromkeys(keys, default=None): 通过传入的关键字返回一个新的ParameterDict,default标识为所有keys设置的值
- get(key, default=None):
- items(): 迭代返回参数字典key、value对迭代器
- keys(): 返回参数字典所有的关键字
- pop(key): 删除参数“key”
- popitem(): 删除并返回最后一次插入的参数
- setdefault(key, default=None): 如果当前参数字典有关键字key,则返回key对应的参数;如果没有就插入key,它对应的参数是default的值,并返回default的值。
- update(parameters): 根据parameters中的key-value对进行参数字典的更新,parameters的顺序不会改变参数字典的顺序,新加的参数默认追加。具体可以参考ModuleDict的对应方法。
- values(): 返回所有参数对象
六、对于Module的全局钩子
- register_module_forward_pre_hook
- register_module_forward_hook
- register_module_backward_hook
- register_module_full_backward_hook
>>> def elfin_hook(module, in_data):
... print(list(module.named_children()))
>>> nn.modules.module.register_module_forward_pre_hook(elfin_hook)
<torch.utils.hooks.RemovableHandle at 0x7f43743af1d0>
这里我们注册钩子的方式和环境变量的声明是类似的!需要进行如上的操作,在模型执行过程中,钩子自然会被调用!
完!
清澈的爱,只为中国