【pytorch】土堆pytorch教程学习(六)神经网络的基本骨架——nn.module的使用

torch.nn 是 pytorch 的一个神经网络库(nn 是 neural network 的简称)。

Containers

torch.nn 构建神经网络的模型容器(Containers,骨架)有以下六个:

  • Module
  • Sequential
  • ModuleList
  • ModuleDict
  • ParameterList
  • ParameterDict

本文将介绍神经网络的基本骨架——nn.module的使用。

Module

所有神经网络模块的基类。自定义的模型也应该继承该类。

创建模型有两个要素:构建子模块拼接子模块。构建子模块包括构建卷积层、池化层、全连接层等。拼接子模块即按照一定的顺序把构建好的子模块拼接起来。

自定义模型继承该类要重写 __init__()forward()

  • __init__() 里构建子模块,将子模块作为当前模块类的常规属性。一般将网络中具有可学习参数的层放在__init__中。
  • forward() 前向传播函数,拼接子模块。
# 官方案例
import torch.nn as nn
import torch.nn.functional as F

# 自定义模型
class Model(nn.Module):
    def __init__(self):
        super().__init__() # 在对子类进行赋值之前,必须对父类进行__init__调用。
        # 构建子模块
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    # 前向传播函数
    def forward(self, x):
        # 拼接子模块
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
        
# 模型调用
x = torch.randn(3, 1, 10, 20)
model = Model()
y = model(x)

为什么 forward() 方法能在model(x)时自动调用?
在 python 中当一个类定义了 __call__方法,则这个类实例就成为了可调用对象。而nn.Module 中的 __call__ 方法中调用了 forward() 方法,因此继承了 nn.Module 的子类对象就可以通过 model(x) 来调用 forward() 函数。


只要在 nn.Module 的子类中定义了 forward 函数,backward 函数就会被自动实现(利用Autograd)。

总结:自定义网络模型需要继承 nn.Module,并实现 __init__forward 函数。一个 Module 里可包含多个子 Module,比如 LeNet 是一个 Module,里面包括多个卷积层、池化层、全连接层等子 module。

posted @   hzyuan  阅读(97)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)

喜欢请打赏

扫描二维码打赏

支付宝打赏

点击右上角即可分享
微信分享提示