alex_bn_lee

导航

< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

统计

【501】pytorch教程之nn.Module类详解

参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型

  pytorch中对于一般的序列模型,直接使用torch.nn.Sequential类及可以实现,这点类似于keras,但是更多的时候面对复杂的模型,比如:多输入多输出、多分支模型、跨层连接模型、带有自定义层的模型等,就需要自己来定义一个模型了。本文将详细说明如何让使用Mudule类来自定义一个模型。

  pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的。

  我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。但有一些注意技巧:

  • 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;
  • 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
  • forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

  所有放在构造函数__init__里面的层的都是这个模型的“固有属性”。

  官方例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn
import torch.nn.functional as F
 
class Model(nn.Module):
    def __init__(self):
        # 固定内容
        super(Model, self).__init__()
 
        # 定义相关的函数
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
 
    def forward(self, x):
        # 构建模型结构,可以使用F函数内容,其他调用__init__里面的函数
        x = F.relu(self.conv1(x))
 
        # 返回最终的结果
        return F.relu(self.conv2(x))

☀☀☀<< 举例 >>☀☀☀

  代码一:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
 
N, D_in, H, D_out = 64, 1000, 100, 10
  
torch.manual_seed(1)
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
  
#-----changed part-----#
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
#-----changed part-----#
 
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  代码二:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
 
N, D_in, H, D_out = 64, 1000, 100, 10
 
torch.manual_seed(1)
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
 
#-----changed part-----#
class Alex_nn(nn.Module):
    def __init__(self):
        super(Alex_nn, self).__init__()
        self.h1 = torch.nn.Linear(D_in, H)
        self.h1_relu = torch.nn.ReLU()
        self.output = torch.nn.Linear(H, D_out)
         
    def forward(self, x):
        h1 = self.h1(x)
        h1_relu = self.h1_relu(h1)
        output = self.output(h1_relu)
        return output
         
model = Alex_nn()
#-----changed part-----#
 
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  代码三:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
 
N, D_in, H, D_out = 64, 1000, 100, 10
 
torch.manual_seed(1)
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
 
#-----changed part-----#
class Alex_nn(nn.Module):
    def __init__(self, D_in_, H_, D_out_):
        super(Alex_nn, self).__init__()
        self.D_in = D_in_
        self.H = H_
        self.D_out = D_out_
         
        self.h1 = torch.nn.Linear(self.D_in, self.H)
        self.h1_relu = torch.nn.ReLU()
        self.output = torch.nn.Linear(self.H, self.D_out)
         
    def forward(self, x):
        h1 = self.h1(x)
        h1_relu = self.h1_relu(h1)
        output = self.output(h1_relu)
        return output
         
model = Alex_nn(D_in, H, D_out)
#-----changed part-----#
 
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  

 

 

 

 

 

posted on   McDelfino  阅读(3060)  评论(0编辑  收藏  举报

编辑推荐:
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
历史上的今天:
2017-12-06 【271】IDL-ENVI二次开发
2017-12-06 【270】IDL处理GeoTIFF数据
2016-12-06 【231】罗技优联接收器配对使用方法
点击右上角即可分享
微信分享提示