【501】pytorch教程之nn.Module类详解
参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型
pytorch中对于一般的序列模型,直接使用torch.nn.Sequential类及可以实现,这点类似于keras,但是更多的时候面对复杂的模型,比如:多输入多输出、多分支模型、跨层连接模型、带有自定义层的模型等,就需要自己来定义一个模型了。本文将详细说明如何让使用Mudule类来自定义一个模型。
pytorch里面一切自定义操作基本上都是继承nn.Module类来实现的。
我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。但有一些注意技巧:
- 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;
- 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
- forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
所有放在构造函数__init__里面的层的都是这个模型的“固有属性”。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__( self ): # 固定内容 super (Model, self ).__init__() # 定义相关的函数 self .conv1 = nn.Conv2d( 1 , 20 , 5 ) self .conv2 = nn.Conv2d( 20 , 20 , 5 ) def forward( self , x): # 构建模型结构,可以使用F函数内容,其他调用__init__里面的函数 x = F.relu( self .conv1(x)) # 返回最终的结果 return F.relu( self .conv2(x)) |
☀☀☀<< 举例 >>☀☀☀
代码一:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | import torch N, D_in, H, D_out = 64 , 1000 , 100 , 10 torch.manual_seed( 1 ) x = torch.randn(N, D_in) y = torch.randn(N, D_out) #-----changed part-----# model = torch.nn.Sequential( torch.nn.Linear(D_in, H), torch.nn.ReLU(), torch.nn.Linear(H, D_out), ) #-----changed part-----# loss_fn = torch.nn.MSELoss(reduction = 'sum' ) learning_rate = 1e - 4 optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate) for t in range ( 500 ): y_pred = model(x) loss = loss_fn(y_pred, y) if t % 100 = = 99 : print (t, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step() |
代码二:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import torch N, D_in, H, D_out = 64 , 1000 , 100 , 10 torch.manual_seed( 1 ) x = torch.randn(N, D_in) y = torch.randn(N, D_out) #-----changed part-----# class Alex_nn(nn.Module): def __init__( self ): super (Alex_nn, self ).__init__() self .h1 = torch.nn.Linear(D_in, H) self .h1_relu = torch.nn.ReLU() self .output = torch.nn.Linear(H, D_out) def forward( self , x): h1 = self .h1(x) h1_relu = self .h1_relu(h1) output = self .output(h1_relu) return output model = Alex_nn() #-----changed part-----# loss_fn = torch.nn.MSELoss(reduction = 'sum' ) learning_rate = 1e - 4 optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate) for t in range ( 500 ): y_pred = model(x) loss = loss_fn(y_pred, y) if t % 100 = = 99 : print (t, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step() |
代码三:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | import torch N, D_in, H, D_out = 64 , 1000 , 100 , 10 torch.manual_seed( 1 ) x = torch.randn(N, D_in) y = torch.randn(N, D_out) #-----changed part-----# class Alex_nn(nn.Module): def __init__( self , D_in_, H_, D_out_): super (Alex_nn, self ).__init__() self .D_in = D_in_ self .H = H_ self .D_out = D_out_ self .h1 = torch.nn.Linear( self .D_in, self .H) self .h1_relu = torch.nn.ReLU() self .output = torch.nn.Linear( self .H, self .D_out) def forward( self , x): h1 = self .h1(x) h1_relu = self .h1_relu(h1) output = self .output(h1_relu) return output model = Alex_nn(D_in, H, D_out) #-----changed part-----# loss_fn = torch.nn.MSELoss(reduction = 'sum' ) learning_rate = 1e - 4 optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate) for t in range ( 500 ): y_pred = model(x) loss = loss_fn(y_pred, y) if t % 100 = = 99 : print (t, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step() |
posted on 2020-12-06 14:25 McDelfino 阅读(3060) 评论(0) 编辑 收藏 举报
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
2017-12-06 【271】IDL-ENVI二次开发
2017-12-06 【270】IDL处理GeoTIFF数据
2016-12-06 【231】罗技优联接收器配对使用方法