自定义层
一、前言
深度学习成功的背后一个因素是可以用创造性的方式组合广泛的层,从而设计出适合于各种任务的结构
二、不带参数的层
1、要构建一个层,我们只需继承基础层类并实现正向传播功能
1 2 3 4 5 6 7 8 9 10 11 12 13 | # 构造不带参数的层 # 下面的CenteredLayer类要从其输入中减去均值。 import torch import torch.nn.functional as F from torch import nn class CenteredLayer(nn.Module): def __init__(self): super().__init__() # 从其输入减去均值 def forward(self, X): return X - X.mean() |
2、测试是否按预期工作
1 2 3 4 5 6 | layer = CenteredLayer() layer(torch.FloatTensor([1, 2, 3, 4, 5])) #输出结果 tensor([-2., -1., 0., 1., 2.]) |
3、将层作为组件合并构建到复杂模型中
1 2 | # 将层作为组件合并构建到复杂模型中 net = nn.Sequential(nn.Linear(8, 128), CenteredLayer()) |
三、带参数的层
1、既然我们知道了如何定义简单的层,接下来继续定义具有参数的层,这些参数可以通过训练进行调整
2、我们可以使用内置函数来创建参数,这些函数提供一些基本的管理功能
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | # 定义带有参数的层 class MyLinear(nn.Module): # in_units:输入数量 # units:输出数量 def __init__(self, in_units, units): super().__init__() '' ' 首先可以把这个函数理解为类型转换函数 将一个不可训练的类型Tensor转换成可以训练的类型parameter 并将parameter绑定到这个module里面 '' ' self.weight = nn.Parameter(torch.randn(in_units, units)) self.bias = nn.Parameter(torch.randn(units,)) def forward(self, X): linear = torch.matmul(X, self.weight.data) + self.bias.data return F.relu(linear) |
3、实例化MyLinear类并访问其模型参数
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | linear = MyLinear(5, 3) print(linear.bias) linear.weight #输出结果 Parameter containing: tensor([ 2.0313, -0.5231, 2.2049], requires_grad=True) Parameter containing: tensor([[-0.5891, -0.0976, -2.2352], [-1.3207, -0.3231, 0.1074], [ 0.8634, -0.6129, -0.4620], [ 0.4784, -0.1825, 0.4654], [-0.7650, -0.5062, -0.8821]], requires_grad=True) |
4、使用自定义层直接执行正向传播计算
1 2 3 4 5 6 | linear(torch.rand(2, 5)) #输出结果 tensor([[2.0072, 0.0000, 0.7855], [1.4918, 0.0000, 1.8945]]) |
5、使用自定义层构建模型,可以像使用内置的全连接层一样使用自定义层
1 2 3 4 5 6 7 | net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1)) net(torch.rand(2, 64)) #输出结果 tensor([[0.], [0.]]) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· winform 绘制太阳,地球,月球 运作规律
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)