7. pytorch 现有网络模型的使用与修改和模型的保存与加载
正文
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。他提供了大量的模型供我们所使用,如下图所示:
下面,我们选择其中一个网络进行使用,介绍如何使用、并修改 pytorch 本身为我们提供的现有网络。最后介绍一下模型的保存和修改。
pytorch 现有网络的使用与修改
下面我们以 VGG(Very Deep Convolutional Networks for Large-Scale Image Recognition)的使用为例,进行介绍该网络。
VGG 16 简介
VGG16网络是14年牛津大学计算机视觉组和Google DeepMind公司研究员一起研发的深度网络模型。该网络一共有16个训练参数的网络,该网络的具体网络结构如下所示:
不难看出,该网络主要用于对 224 x 224 的图像进行 1000 分类。下面我们查看 VGG 在 pytorch 上的官方文档。
VGG 16 doc
从帮助文档中,我们可以清楚的看到 pytorch 为我们提供了各种版本的 VGG,我们选择 VGG 16 进行查看。
VGG16 的简单使用
从 vgg 16的帮助文档可以得知,该模型训练的数据是 ImageNet
,我们进入 torchvision.datasets 查看 ImageNet
但是该数据集实在是太大了,根本下不了,还是不搞了。建立一个该网络的模型查看参数: ```python import torch import torchvision import torch.nn as nn # import torchvision.models
vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True)
vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True)
print(vgg_model_original)
print(vgg_model_pretrained)
vgg_model_pretrained.add_module()
<p align="center"> <img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091424359-823205952.png" style="zoom:100%"/> </p> <br/> <p align="center"> <img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113091441147-673036772.png" style="zoom:100%"/> </p> <br/> 仔细查看这个网络的组成,你可以发现,组成该网络的一个个小 module 就是我们之前所介绍过的`Conv2d`, `ReLU`, `MaxPool2d`, `Linear`, `Dropout` 等等函数, ### VGG16 模型修改 经过上面的代码,我们可以较为轻松的看到 VGG16 神经网络的结构框架,那么我们如何修改别人已经写好的模型呢? 想要修改别人写好的模型,主要有一下这几种操作 <p align="center"> <img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113092149496-1776250931.png" style="zoom:100%"/> </p> <br/> 选中模型,进行 add_module() 或者是直接对模型进行修改 <p align="center"> <img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094536565-291308585.png" style="zoom:100%"/> </p> <br/> <p align="center"> <img src="https://img2020.cnblogs.com/blog/1772262/202111/1772262-20211113094752626-1549811574.png" style="zoom:100%"/> </p> <br/> ```python import torch import torchvision import torch.nn as nn # import torchvision.models vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True) vgg_model_original = torchvision.models.vgg16(pretrained=False, progress=True) print(vgg_model_original) print(vgg_model_pretrained) # vgg_model_pretrained.add_module() vgg_model_original.classifier.add_module('15', nn.Linear(in_features=1000, out_features=10, bias=True)) print(vgg_model_original) vgg_model_original.classifier[7] = nn.Linear(in_features=1000, out_features=15, bias=True) print(vgg_model_original)
根据上诉代码,我们就将 1000 分类问题的网络修改成了 10 分类或者是 15 分类问题的网络了。
模型的保存和加载
当我们利用数据将模型训练好之后,往往需要保存模型。同时,当我们创建模型的时候,也可能需要加载我们之前已经训练好的参数,下面我来介绍一下操作方法。
保留模型结构和模型参数
通过 torch.save() 和 torch.load() 进行保存模型和参数
import torch import torchvision import torch.nn as nn # import torchvision.models vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True) torch.save(vgg_model_pretrained, "../../models_param/vgg_model_pretrained.pth") vgg_model_load = torch.load(f="../../models_param/vgg_model_pretrained.pth") print(111)
打一个断点,查看保存模型和加载模型的参数情况
仅保留模型参数
同样是使用 save 和 load 参数,但是用法有所不同,他所保存的是一个模型参数,以字典dict 的形式保存
import torch import torchvision import torch.nn as nn # import torchvision.models vgg_model_pretrained = torchvision.models.vgg16(pretrained=True, progress=True) torch.save(vgg_model_pretrained.state_dict(), "../../models_param/vgg_model_pretrained_method2.pth") vgg_model_load_method2 = torchvision.models.vgg16() vgg_model_load_method2.load_state_dict(torch.load("../../models_param/vgg_model_pretrained_method2.pth")) print("this is a breakpoint!")
断点查看 save 和 load 模型的参数情况
一模一样,没有问题。
Date:2021/11/13
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· winform 绘制太阳,地球,月球 运作规律
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)