Pytorch——Pytorch Lightning框架使用手册
本文主要是记录下,使用PytorchLightning这个如何进行深度学习的训练,记录一下本人平常使用这个框架所需要注意的地方,由于框架的理解深入本文会时不时进行更新(第三部分的常见问题会是不是的更新走的),本文深度参考以下两个网站pytorch_lightning 全程笔记 、Pytorch Lightning 完全攻略如果大家觉得本文写得不是很清楚,大家可以进一步看看这两篇文章。
一、框架使用方案
正如网络上大家介绍的那样,PL框架可以让人专心在模型内部的研究。我们在复杂的项目中,可能会出现多个模型,并且模型多个模型之间存在着许多的联系,如果在项目中想要更换某些模型model,会导致重写很多代码。但是如果采用PL框架,那么这将会是一件比较容易的事情。根据Pytorch Lightning 完全攻略这篇文章的推荐,我建议采用以下的代码风格:
root-
|-dataModule
|-__init__.py
|-data_interface.py
|-xxxdataset1.py
|-xxxdataset2.py
|-...
|-modelModule
|-__init__.py
|-model_interface.py
|-xxxmodel1.py
|-xxxmodel2.py
|-...
|-train.py
其中把dataModule和modelModule写成python包,这两个包的__init__.py分别是:
- from .data_interface import DInterface
- from .model_interface import MInterface
在DInterface和MInterface分别是data_interface.py和model_interface.py中创建的类,他们两个分别就是
- class DInterface(pl.LightningDataModule): 用于所有数据集的接口,在setup()方法中初始化你准备好的xxxdataset1.py,xxxdataset2.py中定义的torch.utils.data.Dataset类。在train_dataloader,val_dataloader,test_dataloader这几个方法中载入Dataloader即可。
- class MInterface(pl.LightningModule): 用作模型的接口,在
__init__()
函数中import你准备好的xxxmodel2.py,xxxmodel1.py这些模型。重写training_step方法,validation_step方法,configure_optimizers方法。
当大家在更改模型的时候只需要在对应的模块上进行更改即可,最后train.py主要功能就是读取参数,和调用dataModule和modelModule这两个包进行实例化DInterface和MInterface,当然一些PL框架的回调函数也需要在train.py里进行定义。
二、框架基本模块(Module)
2.1 LightningModule
LightningModule必须包含的部分是init()和training_step(self, batch, batch_idx),其中init()主要是进行模型的初始化和定义(不需要定义数据集等)。training_step(...)主要是进行定义每个batch数据的处理函数。
1 2 3 4 5 | def training_step( self , batch, batch_idx): x, y, z = batch out = self .encoder(x) loss = self .loss(out, x) return loss |
这里的batch就是从dataloader中出来的一个batch的数据,类似于for batch in dataloader :。可以看到这个函数的返回值是一个loss,文档中有提到这个方法一定要返回一个loss,如果是返回一个字典,那么必须包含“loss”这个键。当然也有例外,当我们重写training_step_end()方法的时候就不用training_step必须返回一个loss了,此时可以返回任意的东西,但是要注意training_step_end()方法就必须返回一个loss,至于training_step_end()主要是在使用多GPU训练的时候需要重写该方法,主要是进行损失的汇总。
除了training_step,我们还有validation_step,test_step,其中test_step不会在训练中调用,而validation_step则是对测试数据进行模型推理,一般在这个步骤里可以用self.log进行记录某些值,例如:
1 2 3 4 5 | def validation_step( self , batch, batch_idx): pre = model(batch) loss = self .lossfun(...) # log记录 self .log( 'val_loss' ,loss, on_epoch = True , prog_bar = True , logger = True ) |
上面的使用的self.log是非常重要的一个方法,这个方法继承自LightningModule这个父类,我们使用这里log就可以在训练时使用ModelCheckpoint对象(用于保存模型的参数对象)去检测测试步骤中的参数(比如这里我们就可以检测val_loss这个值,来确定是否保存这个模型参数)
self.log()中常用参数以下:
- prog_bar:如果是True,该值将会显示在进度条上
- logger:如果是True,将会记录到logger器中(会显示在tensorboard上)
2.2 LightningDataModule
这一个类必须包含的部分是setup(self, stage=None)方法,train_dataloader()方法。
- setup(self, stage=None):主要是进行Dataset的实例化,包括但不限于进行数据集的划分,划分成训练集和测试集,一般来说都是Dataset类
- train_dataloader():很简单,只需要返回一个DataLoader类即可。
-
val_dataloader():与train_dataloader一样,用于初始化DataLoader对象,并返回
有些时候也会定义collate_fn函数,在DataLoader创建时传入collate_fn参数,用于对Dataset进处理(但实际上一般是在Dataset类中定义,根据Dataset的属性类个性化配置)。
三、常见问题
参考网站:
pytorch_lightning 全程笔记 - 知乎 (zhihu.com)
Pytorch Lightning 完全攻略 - 知乎 (zhihu.com)
PyTorch Lightning初步教程(上) - 知乎 (zhihu.com)
“简约版”Pytorch —— Pytorch-Lightning详解_@YangZai的博客-CSDN博客
201024-5步PyTorchLightning中设置并访问tensorboard_专注机器学习之路-CSDN博客
PyTorch Ligntning】快速上手简明指南_闻韶-CSDN博客
PyTorch Lightning 工具学习 - 知乎 (zhihu.com)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南