从头学pytorch(十一):自定义层

自定义layer

https://www.cnblogs.com/sdu20112013/p/12132786.html一文里说了怎么写自定义的模型.本篇说怎么自定义层.
分两种:

  • 不含模型参数的layer
  • 含模型参数的layer

核心都一样,自定义一个继承自nn.Module的类,在类的forward函数里实现该layer的计算,不同的是,带参数的layer需要用到nn.Parameter

不含模型参数的layer

直接继承nn.Module

import torch
from torch import nn

class CenteredLayer(nn.Module):
    def __init__(self, **kwargs):
        super(CenteredLayer, self).__init__(**kwargs)
    def forward(self, x):
        return x - x.mean()

layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
y.mean().item()

含模型参数的layer

  • Parameter
  • ParameterList
  • ParameterDict

Parameter类其实是Tensor的子类,如果一个TensorParameter,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成Parameter,除了直接定义成Parameter类外,还可以使用ParameterListParameterDict分别定义参数的列表和字典。

ParameterList用法和list类似

class MyDense(nn.Module):
    def __init__(self):
        super(MyDense,self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4,4)) for i in range(4)])
        self.params.append(nn.Parameter(torch.randn(4,1)))

    def forward(self,x):
        for i in range(len(self.params)):
            x = torch.mm(x,self.params[i])
        return x

net = MyDense()
print(net)

输出

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x4]
      (4): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)

ParameterDict用法和python dict类似.也可以用.keys(),.items()

class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
                'linear1': nn.Parameter(torch.randn(4, 4)),
                'linear2': nn.Parameter(torch.randn(4, 1))
        })
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增

    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

net = MyDictDense()
print(net)

print(net.params.keys(),net.params.items())

x = torch.ones(1, 4)
net(x, 'linear1')

输出

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)
odict_keys(['linear1', 'linear2', 'linear3']) odict_items([('linear1', Parameter containing:
tensor([[-0.2275, -1.0434, -1.6733, -1.8101],
        [ 1.7530,  0.0729, -0.2314, -1.9430],
        [-0.1399,  0.7093, -0.4628, -0.2244],
        [-1.6363,  1.2004,  1.4415, -0.1364]], requires_grad=True)), ('linear2', Parameter containing:
tensor([[ 0.5035],
        [-0.0171],
        [-0.8580],
        [-1.1064]], requires_grad=True)), ('linear3', Parameter containing:
tensor([[-1.2078,  0.4364],
        [-0.8203,  1.7443],
        [-1.7759,  2.1744],
        [-0.8799, -0.1479]], requires_grad=True))])

使用自定义的layer构造模型

layer1 = MyDense()
layer2 = MyDictDense()

net = nn.Sequential(layer2,layer1)
print(net)
print(net(x))

输出

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x4]
        (4): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[-4.7566]], grad_fn=<MmBackward>)
posted @   core!  阅读(3720)  评论(0编辑  收藏  举报
编辑推荐:
· 一个超经典 WinForm,WPF 卡死问题的终极反思
· ASP.NET Core - 日志记录系统(二)
· .NET 依赖注入中的 Captive Dependency
· .NET Core 对象分配(Alloc)底层原理浅谈
· 聊一聊 C#异步 任务延续的三种底层玩法
阅读排行:
· 互联网不景气了那就玩玩嵌入式吧,用纯.NET开发并制作一个智能桌面机器人(一):从.NET IoT入
· .NET 开发的分流抢票软件,不做广告、不收集隐私
· 一个超经典 WinForm,WPF 卡死问题的终极反思
· 开箱你的 AI 语音女友「GitHub 热点速览」
· 前端实现 HTML 网页转 PDF 并导出
点击右上角即可分享
微信分享提示