torch.nn ------ 剩余容器

torch.nn ------ 剩余容器

作者:elfin   参考资料来源:torch.nn


Top---Bottom

一、Sequential

​ 一个序列容器。模块将按照它们在构造函数中传递的顺序添加到其中。或者传入OrderedDict模块(value是子模块)。该容器的forward()方法接受任何输入并将其转发到它包含的第一个模块。然后,它将每个后续模块的输出顺序“链接”到输入,最后返回最后一个模块的输出。

​ Sequential可以手动调用提供的模块序列,它允许将整个容器视为单个模块,这样,对Sequential执行可以理解为顺序执行它存储的每个模块(每个模块都是Sequential的注册子模块)。

​ 顺序列表和ModuleList有什么区别?ModuleList就是它听起来的样子:一个存储模块的列表!另一方面,按顺序排列的层和级联方式连接。

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

这里Sequential序列可以看作一个模块进行执行,这是它的最大特点和优势!


Top---Bottom

二、ModuleList

构建一个子模型列表,这个列表不能直接执行!如:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

常见的列表方法这里也是使用的,有:

  • append: 在末尾追加一个子模型,参数module其参数类型为nn.Module
  • extend: 在末尾追加多个子模型
  • inset: 在指定索引处插入子模型

Top---Bottom

三、ModuleDict

构建一个子模型字典,注意这是有序字典。参数modules是一个可迭代对象,如字典,或其他键值对数据。

案例

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

关于字典的方法实现了如下功能:

  • clear(): 清空所有子模块
  • items(): 迭代返回子模块名和子模块
  • keys(): 返回所有子模块名
  • pop(key): 删除子模块“key”
  • update(modules): 更新子模型
  • values(): 返回所有子模块

测试update方法:

>>> net = nn.ModuleDict({"conv": nn.Conv2d(12,64,3), "pool": nn.MaxPool2d(3)})
>>> net
ModuleDict(
  (conv): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> net.update({"conv1": nn.Conv2d(64, 128, 3), "pool": nn.MaxPool2d(2)})
ModuleDict(
  (conv): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
)

注意新加的子模块默认还是在最后!


Top---Bottom

四、ParameterList

​ ParameterList可以像常规Python列表一样被索引,但它包含的参数已正确注册,并且所有模块方法是可使用的。参数parameters是一个可迭代的参数封装,如:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
        )

    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

它实现了append和extend方法!


Top---Bottom

五、ParameterDict

与ModuleDict类似,这也是一个有序的字典。使用案例:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterDict({
                'left': nn.Parameter(torch.randn(5, 10)),
                'right': nn.Parameter(torch.randn(5, 10))
        })

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x

关于参数字典实现的方法有:

  • clear(): 清空所有参数项目
  • copy(): 返回此ParameterDict的副本
  • fromkeys(keys, default=None): 通过传入的关键字返回一个新的ParameterDict,default标识为所有keys设置的值
  • get(key, default=None):
  • items(): 迭代返回参数字典key、value对迭代器
  • keys(): 返回参数字典所有的关键字
  • pop(key): 删除参数“key”
  • popitem(): 删除并返回最后一次插入的参数
  • setdefault(key, default=None): 如果当前参数字典有关键字key,则返回key对应的参数;如果没有就插入key,它对应的参数是default的值,并返回default的值。
  • update(parameters): 根据parameters中的key-value对进行参数字典的更新,parameters的顺序不会改变参数字典的顺序,新加的参数默认追加。具体可以参考ModuleDict的对应方法。
  • values(): 返回所有参数对象

Top---Bottom

六、对于Module的全局钩子

  • register_module_forward_pre_hook
  • register_module_forward_hook
  • register_module_backward_hook
  • register_module_full_backward_hook
>>> def elfin_hook(module, in_data):
...     print(list(module.named_children()))
>>> nn.modules.module.register_module_forward_pre_hook(elfin_hook)
<torch.utils.hooks.RemovableHandle at 0x7f43743af1d0>

这里我们注册钩子的方式和环境变量的声明是类似的!需要进行如上的操作,在模型执行过程中,钩子自然会被调用!


Top---Bottom

完!

posted @ 2022-04-08 11:43  巴蜀秀才  阅读(92)  评论(0编辑  收藏  举报