从零搭建Pytorch模型教程(二)搭建网络
前言
上一篇《
本文介绍了如何搭建神经网络,构建网络的几种方式,前向传播的过程,几种初始化方式,如何加载预训练模型的指定层等内容。本文以CNN为例,下一篇介绍如何搭建Transformer网络。
本文来自公众号CV技术指南的
欢迎关注公众号
搭建CNN网络
首先来看一个CNN网络 (以YOLO_v1的一部分层为例)。
搭建网络有几个要点:
-
自定义类要继承torch.nn.Module。有时候自己设计了一些模块,为了使用更方便,通常额外定义一个类,就像这里的Flatten,自定义的类也要继承torch.nn.Module。
-
完成init函数和forward函数。其中__init__函数完成网络的搭建,forward函数完成网络的前传路径。
-
完成所有层的参数初始化,一般只有卷积层,归一化层,全连接层要初始化,池化层没有参数。
__init__函数
构建网络层有几种方式,一种是pytorch官方已经有了定义的网络,如resnet,vgg,Inception等。一种是自定义层,例如自己设计了一个新的模块。
首先是使用pytorch官方库已经支持的网络,这些网络放在了torchvision.models中,下面选择自己需要的一个。
以下只列举了2D 模型的一部分,还有视频类的3D 模型。
若需要加载该网络在ImageNet上预训练的模型,则在括号内设置参数pretrained=True即可。但这种方式有个不好的问题在于这些预训练模型并不是在本地,因此每次运行都会从网上读取加载模型,非常浪费时间。因此,可以去它官网(https://pytorch.org/)上把那个模型下载到本地,通过下面指令完成加载。
另一种自定义层的,一般可以通过torch.nn.Sequential()来构建,在中间插入卷积层、归一化层、激活函数层、池化层即可。
例如下方这种是最常用的。
当网络很深时,上面这种方式构建比较麻烦,例如resnet,总不可能就按找上面这种方式这么写50层。就把它们共同的部分给构建出来,然后通过传参来设置不同的层。
例如:
1.下面这里先构建一个基本的几层作为一个类,每一层的参数(不同输入输出通道数,卷积核大小,有无池化)都通过传参来设置。
2.下面是设置不同的层。注:上面和下面都不是一个完整的代码,只是用来说明这种很多层的构建方式。
forward函数
这里就是网络的传播路径了,一般就是一路往下传就是。return的内容就是网络的输出。
如果想将中间某几层的输出拿出来,做一下特征金字塔,可以像下面这么写。
像可视化特征图里,想要可视化某一层的特征图,就可以像下面这么写。
初始化网络
初始化网络是要放在init函数里完成,分为两类,一类是随机初始化,一类是加载预训练模型。
随机初始化
关于随机初始化,目前主要有多种方式:Normal Initialization, Uniform Initialization,Xavier Initialization,He Initialization (也称 kaiming Initialization),LeCun Initialization。
关于这些初始化方法,可以看这篇文章《
下面是一种方式,直接按自定义的方式初始化。
也可以选择pytorch实现了的初始化。
还可以像下面这么写:
反正随便选择一种就好。
加载预训练模型初始化
加载预训练模型一般是在train文件里写,但有些网络由于是使用现成的backbone网络,例如使用了resnet50,然后后面加了自定义的模块,所以它想要resnet50预训练模型初始化backbone,而其它层做随机初始化,那加载预训练模型就是在网络定义中做的。因此,既然这里提到了初始化,就干脆写在这里。
最简单的就是直接整个模型都加载。
但也有一些情况下,我只想加载其中一部分层的参数。剩下一部分由于已经改变参数了,无法加载预训练模型,所以要选择上面的随机初始化。
这里有必要来说明网络的每一层是如何表示的。下面以一个例子来说明。
这里简单定义了一个网络。在最后面有这两行:
这两行的输出就是打印网络层的名字,实际上加载预训练模型时,也是按照这个名字来加载的。下面是一部分输出。
在预训练模型中就是这样,key即为网络层的名字,value即为它们对应的参数。因此,加载预训练模型可以按照下面这种方式加载。
自己定义的一些层是不会出现在pretrained_dict中,因此会将其剔除,从而只加载了pretrained_dict中有的层。
本文介绍了如何搭建神经网络,构建网络的几种方式,前向传播的过程,几种初始化方式,如何加载预训练模型的指定层等内容。
下一篇我们将介绍如何写train函数,以及包括设置优化方式,设置学习率,不同层设置不同学习率,解析参数等。
欢迎关注公众号
CV技术指南创建了一个交流氛围很不错的群,除了太偏僻的问题,几乎有问必答。关注公众号添加编辑的微信号可邀请加交流群。
在公众号中回复关键字 “入门指南“可获取计算机视觉入门所有必备资料。
其它文章
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!